Compare commits

...

1 Commits

Author SHA1 Message Date
David Sanders
39aed69a33 feat: implement the Prompt API via localAIHandler
Assisted-by: Claude Opus 4.6
2026-04-03 23:15:18 -07:00
67 changed files with 4003 additions and 4 deletions

View File

@@ -611,6 +611,17 @@ source_set("electron_lib") {
]
}
if (enable_prompt_api) {
sources += [
"shell/browser/ai/proxying_ai_manager.cc",
"shell/browser/ai/proxying_ai_manager.h",
"shell/utility/ai/utility_ai_language_model.cc",
"shell/utility/ai/utility_ai_language_model.h",
"shell/utility/ai/utility_ai_manager.cc",
"shell/utility/ai/utility_ai_manager.h",
]
}
if (is_mac) {
# Disable C++ modules to resolve linking error when including MacOS SDK
# headers from third_party/electron_node/deps/uv/include/uv/darwin.h

View File

@@ -12,6 +12,7 @@ buildflag_header("buildflags") {
"ENABLE_PDF_VIEWER=$enable_pdf_viewer",
"ENABLE_ELECTRON_EXTENSIONS=$enable_electron_extensions",
"ENABLE_BUILTIN_SPELLCHECKER=$enable_builtin_spellchecker",
"ENABLE_PROMPT_API=$enable_prompt_api",
"OVERRIDE_LOCATION_PROVIDER=$enable_fake_location_provider",
]

View File

@@ -17,6 +17,9 @@ declare_args() {
# Enable Spellchecker support
enable_builtin_spellchecker = true
# Enable Prompt API support.
enable_prompt_api = true
# The version of Electron.
# Packagers and vendor builders should set this in gn args to avoid running
# the script that reads git tag.

View File

@@ -0,0 +1,85 @@
## Class: LanguageModel
> Implement local AI language models
Process: [Utility](../glossary.md#utility-process)
### `new LanguageModel(initialState)`
* `initialState` Object
* `contextUsage` number
* `contextWindow` number
> [!NOTE]
> Do not use this constructor directly outside of the class itself, as it will not be properly connected to the `localAIHandler`
### Static Methods
The `LanguageModel` class has the following static methods:
#### `LanguageModel.create(options)` _Experimental_
* `options` [LanguageModelCreateOptions](structures/language-model-create-options.md)
Returns `Promise<LanguageModel>`. Creates a new `LanguageModel` with the provided `options`.
#### `LanguageModel.availability([options])` _Experimental_
* `options` [LanguageModelCreateCoreOptions](structures/language-model-create-core-options.md) (optional)
Returns `Promise<string>`
Determines the availability of the language model and returns one of the following strings:
* `available`
* `downloadable`
* `downloading`
* `unavailable`
### Instance Properties
The following properties are available on instances of `LanguageModel`:
#### `languageModel.contextUsage` _Experimental_
A `number` representing how many tokens are currently in the context window.
#### `languageModel.contextWindow` _Experimental_
A `number` representing the size of the context window, in tokens.
### Instance Methods
The following methods are available on instances of `LanguageModel`:
#### `languageModel.prompt(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
Returns `Promise<string> | Promise<import('stream/web').ReadableStream<string>>`. Prompt the model for a response.
#### `languageModel.append(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelAppendOptions](structures/language-model-append-options.md)
Returns `Promise<undefined>`. Append a message without prompting for a response.
#### `languageModel.measureContextUsage(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
Returns `Promise<number>`. Measure how many tokens the input would use.
#### `languageModel.clone(options)` _Experimental_
* `options` [LanguageModelCloneOptions](structures/language-model-clone-options.md)
Returns `Promise<LanguageModel>`. Clones the `LanguageModel`, preserving its
context and initial prompts.
#### `languageModel.destroy()` _Experimental_
Destroys the model, and any ongoing executions are aborted.

View File

@@ -0,0 +1,26 @@
# localAIHandler
> Proxy built-in AI APIs to a local LLM implementation
Process: [Utility](../glossary.md#utility-process)
This module is intended to be used by a script registered to a session via
[`ses.registerLocalAIHandler(handler)`](./session.md#sesregisterlocalaihandlerhandler-experimental)
## Methods
The `localAIHandler` module has the following methods:
#### `localAIHandler.setPromptAPIHandler(handler)` _Experimental_
* `handler` Function\<typeof [LanguageModel](language-model.md)\> | null
* `details` Object
* `webContentsId` Integer - The [unique id](web-contents.md#contentsid-readonly) of
the [WebContents](web-contents.md) calling the Prompt API.
* `securityOrigin` string - Origin of the page calling the Prompt API.
Sets the handler for new Prompt API binding requests from the renderer process. This happens
once per pair of `webContentsId` and `securityOrigin`. Clearing the handler by calling
`setPromptAPIHandler(null)` will prevent new Prompt API sessions from being started,
but will not invalidate existing ones. If you want to invalidate existing Prompt API sessions,
clear the local AI handler for the session using `ses.registerLocalAIHandler(null)`.

View File

@@ -1632,6 +1632,12 @@ This method clears more types of data and is more thorough than the
For more information, refer to Chromium's [`BrowsingDataRemover` interface][browsing-data-remover].
#### `ses.registerLocalAIHandler(handler)` _Experimental_
* `handler` [UtilityProcess](utility-process.md#class-utilityprocess) | null
Registers a local AI handler `UtilityProcess`. To clear the handler, call `registerLocalAIHandler(null)`, which will disconnect any existing Prompt API sessions and destroy any `LanguageModel` instances.
### Instance Properties
The following properties are available on instances of `Session`:

View File

@@ -0,0 +1,3 @@
# LanguageModelAppendOptions Object
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,3 @@
# LanguageModelCloneOptions Object
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,4 @@
# LanguageModelCreateCoreOptions Object
* `expectedInputs` [LanguageModelExpected[]](language-model-expected.md) (optional)
* `expectedOutputs` [LanguageModelExpected[]](language-model-expected.md) (optional)

View File

@@ -0,0 +1,4 @@
# LanguageModelCreateOptions Object extends `LanguageModelCreateCoreOptions`
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
* `initialPrompts` [LanguageModelMessage[]](language-model-message.md) (optional)

View File

@@ -0,0 +1,7 @@
# LanguageModelExpected Object
* `type` string - Can be one of the following values:
* `text`
* `image`
* `audio`
* `languages` string[] (optional)

View File

@@ -0,0 +1,7 @@
# LanguageModelMessageContent Object
* `type` string - Can be one of the following values:
* `text`
* `image`
* `audio`
* `value` ArrayBuffer | string

View File

@@ -0,0 +1,8 @@
# LanguageModelMessage Object
* `role` string - Can be one of the following values:
* `system`
* `user`
* `assistant`
* `content` [LanguageModelMessageContent[]](language-model-message-content.md)
* `prefix` boolean (optional)

View File

@@ -0,0 +1,4 @@
# LanguageModelPromptOptions Object
* `responseConstraint` Object (optional)
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,175 @@
---
title: Local AI Handler
description: Handle built-in AI APIs with a local LLM implementation
slug: local-ai-handler
hide_title: true
---
# Local AI Handler
> **This API is experimental.** It may change or be removed in future Electron releases.
Electron supports [Prompt API](https://github.com/webmachinelearning/prompt-api)
(`LanguageModel`) web API by letting you route calls to a local LLM running in a
[utility process](../api/utility-process.md). Web content calls
`LanguageModel.create()` and `LanguageModel.prompt()` like it would in any
browser, while your Electron app decides which model handles the request.
## How it works
The local AI handler architecture involves three processes:
1. **Main process** — creates `UtilityProcess`, and then registers it to handle
Prompt API calls for a given session via [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental).
2. **Utility process** — runs a script that calls
[`localAIHandler.setPromptAPIHandler()`](../api/local-ai-handler.md#localaihandlersetpromptapihandlerhandler-experimental)
to supply a `LanguageModel` subclass.
3. **Renderer process** — web content uses the standard `LanguageModel` API
(e.g. `LanguageModel.create()`, `model.prompt()`).
When a renderer calls the Prompt API, Electron proxies the request through the
main process to the registered utility process, which invokes your
`LanguageModel` implementation and sends the result back directly to the renderer.
## Prerequisites
The `AIPromptAPI` Blink feature must be enabled on any `BrowserWindow` that will
use the Prompt API. To enable multi-modal inputs, add the
`AIPromptAPIMultimodalInput` feature as well.
```js
const win = new BrowserWindow({
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
})
```
## Quick start
### 1. Create the utility process script
The utility process script registers your `LanguageModel` subclass. The
handler function receives a `details` object with information about the
caller, and must return a class that extends `LanguageModel`.
```js title='ai-handler.js (Utility Process)'
const { localAIHandler, LanguageModel } = require('electron/utility')
localAIHandler.setPromptAPIHandler((details) => {
// details.webContentsId — ID of the calling WebContents
// details.securityOrigin — origin of the calling page
return class MyLanguageModel extends LanguageModel {
static async create (options) {
// options.signal - AbortSignal to cancel the creation of the model
// options.initialPrompts - initial prompts to pass to the language model
return new MyLanguageModel({
contextUsage: 0,
contextWindow: 4096
})
}
static async availability () {
// Return 'available', 'downloadable', 'downloading', or 'unavailable'
return 'available'
}
async prompt (input) {
// input is a string or LanguageModelMessage[]
// Return a string response from your model, or a ReadableStream
// to return a streaming response.
return 'This is a response from your local LLM!'
}
async clone () {
return new MyLanguageModel({
contextUsage: this.contextUsage,
contextWindow: this.contextWindow
})
}
destroy () {
// Clean up model resources
}
}
})
```
### 2. Register the handler in the main process
Fork the utility process and register it as the AI handler for a session:
```js title='main.js (Main Process)'
const { app, BrowserWindow, utilityProcess } = require('electron')
const path = require('node:path')
app.whenReady().then(() => {
// Fork the utility process running your AI handler script
const aiHandler = utilityProcess.fork(path.join(__dirname, 'ai-handler.js'))
// Create a window with the Prompt API enabled
const win = new BrowserWindow({
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
})
// Connect the AI handler to this session
win.webContents.session.registerLocalAIHandler(aiHandler)
win.loadFile('index.html')
})
```
### 3. Use the Prompt API in your renderer
Your web content can now use the standard `LanguageModel` API, which is
available as a global in the renderer:
```html title='index.html (Renderer Process)'
<script>
async function askAI () {
const model = await LanguageModel.create()
const response = await model.prompt('What is Electron?')
document.getElementById('response').textContent = response
}
</script>
<button onclick="askAI()">Ask AI</button>
<p id="response"></p>
```
## Implementing a real model
The quick-start example returns a hardcoded string. A real implementation
would integrate with a local model. See [`electron/llm`](https://github.com/electron/llm)
for an example of using `node-llama-cpp` to wire up GGUF (GPT-Generated Unified Format) models.
## Clearing the handler
To disconnect the AI handler from a session, pass `null`:
```js @ts-type={win:Electron.BrowserWindow}
win.webContents.session.registerLocalAIHandler(null)
```
After clearing, any `LanguageModel.create()` calls from renderers using that
session will fail.
## Security considerations
The `details` object passed to your handler includes `webContentsId` and
`securityOrigin`. Use these to decide whether to handle a request, and
whether to reuse an existing model instance or supply a fresh one so that
origins remain properly isolated from each other.
## Further reading
- [`localAIHandler` API reference](../api/local-ai-handler.md)
- [`LanguageModel` API reference](../api/language-model.md)
- [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental)
- [`utilityProcess.fork()`](../api/utility-process.md#utilityprocessforkmodulepath-args-options)
- [`electron/llm`](https://github.com/electron/llm)

View File

@@ -30,6 +30,8 @@ auto_filenames = {
"docs/api/ipc-main-service-worker.md",
"docs/api/ipc-main.md",
"docs/api/ipc-renderer.md",
"docs/api/language-model.md",
"docs/api/local-ai-handler.md",
"docs/api/menu-item.md",
"docs/api/menu.md",
"docs/api/message-channel-main.md",
@@ -108,6 +110,14 @@ auto_filenames = {
"docs/api/structures/jump-list-item.md",
"docs/api/structures/keyboard-event.md",
"docs/api/structures/keyboard-input-event.md",
"docs/api/structures/language-model-append-options.md",
"docs/api/structures/language-model-clone-options.md",
"docs/api/structures/language-model-create-core-options.md",
"docs/api/structures/language-model-create-options.md",
"docs/api/structures/language-model-expected.md",
"docs/api/structures/language-model-message-content.md",
"docs/api/structures/language-model-message.md",
"docs/api/structures/language-model-prompt-options.md",
"docs/api/structures/media-access-permission-request.md",
"docs/api/structures/memory-info.md",
"docs/api/structures/memory-usage-details.md",
@@ -400,6 +410,8 @@ auto_filenames = {
"lib/common/init.ts",
"lib/common/webpack-globals-provider.ts",
"lib/utility/api/exports/electron.ts",
"lib/utility/api/language-model.ts",
"lib/utility/api/local-ai-handler.ts",
"lib/utility/api/module-list.ts",
"lib/utility/api/net.ts",
"lib/utility/init.ts",

View File

@@ -748,6 +748,8 @@ filenames = {
"shell/services/node/node_service.h",
"shell/services/node/parent_port.cc",
"shell/services/node/parent_port.h",
"shell/utility/api/electron_api_local_ai_handler.cc",
"shell/utility/api/electron_api_local_ai_handler.h",
"shell/utility/electron_content_utility_client.cc",
"shell/utility/electron_content_utility_client.h",
]

View File

@@ -2,7 +2,7 @@ import { fetchWithSession } from '@electron/internal/browser/api/net-fetch';
import { addIpcDispatchListeners } from '@electron/internal/browser/ipc-dispatch';
import * as deprecate from '@electron/internal/common/deprecate';
import { net } from 'electron/main';
import { net, type UtilityProcess } from 'electron/main';
const { fromPartition, fromPath, Session } = process._linkedBinding('electron_browser_session');
const { isDisplayMediaSystemPickerAvailable } = process._linkedBinding('electron_browser_desktop_capturer');
@@ -111,6 +111,12 @@ Session.prototype.removeExtension = deprecate.moveAPI(
'session.extensions.removeExtension'
);
Session.prototype.registerLocalAIHandler = function (handler: UtilityProcess | null) {
// We need to unwrap the userland `ForkUtilityProcess` object and get the underlying
// `ElectronInternal.UtilityProcessWrapper` before we call the C++ function
return this._registerLocalAIHandler(handler !== null ? (handler as any)._unwrapHandle() : null);
};
export default {
fromPartition,
fromPath,

View File

@@ -131,6 +131,10 @@ class ForkUtilityProcess extends EventEmitter implements Electron.UtilityProcess
return this.#stderr;
}
_unwrapHandle () {
return this.#handle;
}
postMessage (message: any, transfer?: MessagePortMain[]) {
if (Array.isArray(transfer)) {
transfer = transfer.map((o: any) => o instanceof MessagePortMain ? o._internalPort : o);

View File

@@ -0,0 +1,44 @@
interface LanguageModelConstructorValues {
contextUsage: number;
contextWindow: number;
}
export default class LanguageModel implements Electron.LanguageModel {
contextUsage: number;
contextWindow: number;
constructor (values: LanguageModelConstructorValues) {
this.contextUsage = values.contextUsage;
this.contextWindow = values.contextWindow;
}
static async create (): Promise<LanguageModel> {
return new LanguageModel({
contextUsage: 0,
contextWindow: 0
});
}
static async availability () {
return 'available';
}
async prompt () {
return '';
}
async append (): Promise<undefined> {}
async measureContextUsage () {
return 0;
}
async clone () {
return new LanguageModel({
contextUsage: this.contextUsage,
contextWindow: this.contextWindow
});
}
destroy () {}
}

View File

@@ -0,0 +1,3 @@
const binding = process._linkedBinding('electron_utility_local_ai_handler');
export const setPromptAPIHandler = binding.setPromptAPIHandler;

View File

@@ -1,5 +1,7 @@
// Utility side modules, please sort alphabetically.
export const utilityNodeModuleList: ElectronInternal.ModuleEntry[] = [
{ name: 'localAIHandler', loader: () => require('./local-ai-handler') },
{ name: 'LanguageModel', loader: () => require('./language-model') },
{ name: 'net', loader: () => require('./net') },
{ name: 'systemPreferences', loader: () => require('@electron/internal/browser/api/system-preferences') }
];

View File

@@ -1,6 +1,8 @@
import LanguageModel from '@electron/internal/utility/api/language-model';
import { ParentPort } from '@electron/internal/utility/parent-port';
import { EventEmitter } from 'events';
import { ReadableStream } from 'stream/web';
import { pathToFileURL } from 'url';
const v8Util = process._linkedBinding('electron_common_v8_util');
@@ -10,6 +12,11 @@ const entryScript: string = v8Util.getHiddenValue(process, '_serviceStartupScrip
// we need to restore it here.
process.argv.splice(1, 1, entryScript);
// These are used by C++ to more easily identify these objects.
v8Util.setHiddenValue(global, 'isReadableStream', (val: unknown) => val instanceof ReadableStream);
v8Util.setHiddenValue(global, 'isLanguageModel', (val: unknown) => val instanceof LanguageModel);
v8Util.setHiddenValue(global, 'isLanguageModelClass', (val: any) => Object.is(val, LanguageModel) || val?.prototype instanceof LanguageModel || false);
// Import common settings.
require('@electron/internal/common/init');

View File

@@ -150,3 +150,4 @@ fix_use_fresh_lazynow_for_onendworkitemimpl_after_didruntask.patch
fix_pulseaudio_stream_and_icon_names.patch
fix_fire_menu_popup_start_for_dynamically_created_aria_menus.patch
feat_allow_enabling_extensions_on_custom_protocols.patch
reject_prompt_api_promises_on_mojo_connection_disconnect.patch

View File

@@ -0,0 +1,168 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: David Sanders <dsanders11@ucsbalum.com>
Date: Wed, 1 Apr 2026 21:14:38 -0700
Subject: Reject Prompt API promises on Mojo connection disconnect
Without these changes to reject promises when the Mojo connection
disconnects, these promises will hang indefinitely if the Prompt
API handler is killed or unregistered.
This will be upstreamed to Chromium.
Change-Id: I89a6a076ae35cbaf12a93c517223a524bab3dff0
diff --git a/third_party/blink/renderer/modules/ai/language_model.cc b/third_party/blink/renderer/modules/ai/language_model.cc
index c176575f2dcc049e478d5388ae1934aa3bc59786..b8eda3dd8733b1e7e92f70a8c75678df4a0314c8 100644
--- a/third_party/blink/renderer/modules/ai/language_model.cc
+++ b/third_party/blink/renderer/modules/ai/language_model.cc
@@ -190,6 +190,9 @@ class CloneLanguageModelClient
client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
language_model->GetTaskRunner());
+ receiver_.set_disconnect_handler(
+ BindOnce(&CloneLanguageModelClient::OnConnectionError,
+ WrapWeakPersistent(this)));
language_model_->GetAILanguageModelRemote()->Fork(std::move(client_remote));
}
~CloneLanguageModelClient() override = default;
@@ -232,6 +235,11 @@ class CloneLanguageModelClient
Cleanup();
}
+ void OnConnectionError() {
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
+ /*quota_error_info=*/nullptr);
+ }
+
void ResetReceiver() override { receiver_.reset(); }
private:
@@ -262,6 +270,8 @@ class AppendClient : public GarbageCollected<AppendClient>,
mojo::PendingRemote<mojom::blink::ModelStreamingResponder> client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
language_model->GetTaskRunner());
+ receiver_.set_disconnect_handler(
+ BindOnce(&AppendClient::OnConnectionError, WrapWeakPersistent(this)));
language_model_->GetAILanguageModelRemote()->Append(
std::move(prompts), std::move(client_remote));
}
@@ -317,6 +327,11 @@ class AppendClient : public GarbageCollected<AppendClient>,
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void OnStreaming(const String& text) override {
NOTREACHED() << "Append() should not invoke `OnStreaming()`";
}
@@ -761,6 +776,7 @@ void LanguageModel::ExecuteMeasureInputUsage(
ScriptPromiseResolver<IDLDouble>* resolver,
AbortSignal* signal,
Vector<mojom::blink::AILanguageModelPromptPtr> prompts) {
+ auto reject_fn = RejectOnDestruction(resolver, signal);
language_model_remote_->MeasureInputUsage(
std::move(prompts),
BindOnce(
@@ -783,7 +799,8 @@ void LanguageModel::ExecuteMeasureInputUsage(
}
resolver->Resolve(static_cast<double>(usage.value()));
},
- WrapPersistent(resolver), WrapPersistent(signal)));
+ WrapPersistent(resolver), WrapPersistent(signal))
+ .Then(std::move(reject_fn)));
}
bool LanguageModel::ValidateInput(ScriptState* script_state,
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.cc b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
index ddc6fcda3ffbdc271bcdebfbd85aa711c063fee2..2b63e9e77dc1ed4a0ed9a687527efdc104974e7a 100644
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.cc
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
@@ -509,6 +509,11 @@ void LanguageModelCreateClient::OnError(
Cleanup();
}
+void LanguageModelCreateClient::OnConnectionError() {
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
+ /*quota_error_info=*/nullptr);
+}
+
void LanguageModelCreateClient::ResetReceiver() {
receiver_.reset();
}
@@ -524,6 +529,8 @@ void LanguageModelCreateClient::OnInitialPromptsResolved(
mojo::PendingRemote<mojom::blink::AIManagerCreateLanguageModelClient>
client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(), task_runner_);
+ receiver_.set_disconnect_handler(
+ BindOnce(&LanguageModelCreateClient::OnConnectionError, WrapWeakPersistent(this)));
HeapMojoRemote<mojom::blink::AIManager>& ai_manager_remote =
AIInterfaceProxy::GetAIManagerRemote(GetExecutionContext());
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.h b/third_party/blink/renderer/modules/ai/language_model_create_client.h
index 9ed8dfbefeccf1627d56f5ccc315f06071a63e25..7c8e823608883171c115676db151b70eb2fd055d 100644
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.h
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.h
@@ -49,6 +49,8 @@ class LanguageModelCreateClient
// Process options and create, if the availability result is valid.
void Create(mojom::blink::ModelAvailabilityCheckResult result);
+ void OnConnectionError();
+
// Continue creation after any initial prompts were processed or rejected.
void OnInitialPromptsResolved(
Vector<mojom::blink::AILanguageModelExpectedPtr> expected_inputs,
diff --git a/third_party/blink/renderer/modules/ai/model_execution_responder.cc b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
index 47b65b13adfab4b8f2597a23d38a386915643d1b..fa9b54e1069019a66b8dab6eb0efe5df8b34c11a 100644
--- a/third_party/blink/renderer/modules/ai/model_execution_responder.cc
+++ b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
@@ -84,7 +84,10 @@ class Responder final : public GarbageCollected<Responder>,
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
- return receiver_.BindNewPipeAndPassRemote(task_runner);
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
+ receiver_.set_disconnect_handler(
+ BindOnce(&Responder::OnConnectionError, WrapWeakPersistent(this)));
+ return pending_remote;
}
// `mojom::blink::ModelStreamingResponder` implementation.
@@ -144,6 +147,11 @@ class Responder final : public GarbageCollected<Responder>,
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(
@@ -235,7 +243,10 @@ class StreamingResponder final
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
- return receiver_.BindNewPipeAndPassRemote(task_runner);
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
+ receiver_.set_disconnect_handler(BindOnce(
+ &StreamingResponder::OnConnectionError, WrapWeakPersistent(this)));
+ return pending_remote;
}
ReadableStream* CreateReadableStream() {
@@ -337,6 +348,11 @@ class StreamingResponder final
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(

View File

@@ -0,0 +1,187 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/browser/ai/proxying_ai_manager.h"
#include <optional>
#include <utility>
#include "base/functional/bind.h"
#include "base/notimplemented.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/weak_document_ptr.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "shell/browser/api/electron_api_session.h"
#include "shell/browser/api/electron_api_web_contents.h"
#include "shell/browser/session_preferences.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
namespace electron {
ProxyingAIManager::ProxyingAIManager(content::BrowserContext* browser_context,
content::RenderFrameHost* rfh)
: browser_context_(browser_context),
rfh_(rfh ? rfh->GetWeakDocumentPtr() : content::WeakDocumentPtr()) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
if (session_prefs) {
ai_handler_changed_subscription_ =
session_prefs->AddAIHandlerChangedCallback(
base::BindRepeating(&ProxyingAIManager::OnAIHandlerChanged,
weak_ptr_factory_.GetWeakPtr()));
}
}
ProxyingAIManager::~ProxyingAIManager() = default;
void ProxyingAIManager::OnAIHandlerChanged() {
ai_manager_remote_.reset();
}
void ProxyingAIManager::AddReceiver(
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
receivers_.Add(this, std::move(receiver));
}
const mojo::Remote<blink::mojom::AIManager>&
ProxyingAIManager::GetAIManagerRemote(const SessionPreferences& session_prefs) {
if (!ai_manager_remote_.is_bound()) {
auto* local_ai_handler = session_prefs.GetLocalAIHandler().get();
if (local_ai_handler) {
auto* rfh = rfh_.AsRenderFrameHostIfValid();
DCHECK(rfh);
auto* web_contents = electron::api::WebContents::From(
content::WebContents::FromRenderFrameHost(rfh));
std::optional<int32_t> web_contents_id;
if (web_contents) {
web_contents_id = web_contents->ID();
}
local_ai_handler->BindAIManager(
web_contents_id, rfh->GetLastCommittedOrigin(),
ai_manager_remote_.BindNewPipeAndPassReceiver());
}
}
return ai_manager_remote_;
}
void ProxyingAIManager::CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
DCHECK(session_prefs);
// Default to unavailable. This ensures the callback is always invoked
// even if there is no registered utility process handler, or the
// process crashes.
auto cb = mojo::WrapCallbackWithDefaultInvokeIfNotRun(
std::move(callback),
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
// Proxy the call through to the utility process
auto& ai_manager = GetAIManagerRemote(*session_prefs);
if (ai_manager.is_bound()) {
ai_manager->CanCreateLanguageModel(std::move(options), std::move(cb));
}
}
void ProxyingAIManager::CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
DCHECK(session_prefs);
// Proxy the call through to the utility process
auto& ai_manager = GetAIManagerRemote(*session_prefs);
if (!ai_manager.is_bound()) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
client_remote->OnError(
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
/*quota_error_info=*/nullptr);
return;
}
ai_manager->CreateLanguageModel(std::move(client), std::move(options));
}
void ProxyingAIManager::CanCreateSummarizer(
blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::GetLanguageModelParams(
GetLanguageModelParamsCallback callback) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateWriter(
blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateRewriter(
blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateProofreader(
blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
blink::mojom::AIProofreaderCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
// Download progress reporting is not supported; the observer remote is
// dropped, so no progress events are ever delivered to the page.
void ProxyingAIManager::AddModelDownloadProgressObserver(
    mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
        observer_remote) {
  NOTIMPLEMENTED();
}
} // namespace electron

View File

@@ -0,0 +1,103 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
#define ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
#include "base/callback_list.h"
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "content/public/browser/weak_document_ptr.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "shell/browser/session_preferences.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
namespace content {
class BrowserContext;
class RenderFrameHost;
} // namespace content
namespace electron {
// Owned by the host of the document / service worker via `SupportsUserData`.
// The browser-side implementation of `blink::mojom::AIManager`, which
// proxies requests to a utility process if the session has a registered
// handler.
class ProxyingAIManager : public base::SupportsUserData::Data,
                          public blink::mojom::AIManager {
 public:
  ProxyingAIManager(content::BrowserContext* browser_context,
                    content::RenderFrameHost* rfh);
  ProxyingAIManager(const ProxyingAIManager&) = delete;
  ProxyingAIManager& operator=(const ProxyingAIManager&) = delete;
  ~ProxyingAIManager() override;

  // Binds one more AIManager receiver to this instance; a single instance
  // serves every connection made through the owning host's user data.
  void AddReceiver(mojo::PendingReceiver<blink::mojom::AIManager> receiver);

 private:
  // Lazily bind the AIManager remote so that the developer can
  // set the local AI handler after this class is already created.
  [[nodiscard]] const mojo::Remote<blink::mojom::AIManager>& GetAIManagerRemote(
      const SessionPreferences& session_prefs);

  // Called when the session's registered local AI handler changes
  // (subscribed via `ai_handler_changed_subscription_`).
  void OnAIHandlerChanged();

  // `blink::mojom::AIManager` implementation.
  void CanCreateLanguageModel(
      blink::mojom::AILanguageModelCreateOptionsPtr options,
      CanCreateLanguageModelCallback callback) override;
  void CreateLanguageModel(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      blink::mojom::AILanguageModelCreateOptionsPtr options) override;
  void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
                           CanCreateSummarizerCallback callback) override;
  void CreateSummarizer(
      mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
      blink::mojom::AISummarizerCreateOptionsPtr options) override;
  void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
  void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
                       CanCreateWriterCallback callback) override;
  void CreateWriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
      blink::mojom::AIWriterCreateOptionsPtr options) override;
  void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
                         CanCreateRewriterCallback callback) override;
  void CreateRewriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
      blink::mojom::AIRewriterCreateOptionsPtr options) override;
  void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
                            CanCreateProofreaderCallback callback) override;
  void CreateProofreader(
      mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
          client,
      blink::mojom::AIProofreaderCreateOptionsPtr options) override;
  void AddModelDownloadProgressObserver(
      mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
          observer_remote) override;

  // All bound AIManager pipes served by this instance.
  mojo::ReceiverSet<blink::mojom::AIManager> receivers_;
  raw_ptr<content::BrowserContext> browser_context_;
  // Weak reference to the requesting document, when one exists.
  content::WeakDocumentPtr rfh_;
  // Lazily-bound connection to the handler's AIManager (see
  // GetAIManagerRemote()).
  mojo::Remote<blink::mojom::AIManager> ai_manager_remote_;
  base::CallbackListSubscription ai_handler_changed_subscription_;
  base::WeakPtrFactory<ProxyingAIManager> weak_ptr_factory_{this};
};
} // namespace electron
#endif // ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_

View File

@@ -1555,6 +1555,26 @@ v8::Local<v8::Value> Session::ClearData(gin::Arguments* const args) {
return promise_handle;
}
// Registers (or, when `val` is null, clears) the utility process that
// serves local AI requests for this session. Any other value raises a
// TypeError.
void Session::RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
                                     v8::Local<v8::Value> val) {
  auto* isolate = JavascriptEnvironment::GetIsolate();
  gin_helper::Handle<UtilityProcessWrapper> handler;
  const bool accepted =
      val->IsNull() || gin::ConvertFromV8(isolate, val, &handler);
  if (!accepted) {
    thrower.ThrowTypeError("Must pass null or UtilityProcess");
    return;
  }
  auto* prefs = SessionPreferences::FromBrowserContext(browser_context());
  DCHECK(prefs);
  // An empty handle means the caller passed null: unregister the handler.
  if (handler.IsEmpty()) {
    prefs->SetLocalAIHandler(nullptr);
  } else {
    prefs->SetLocalAIHandler(handler->GetWeakPtr());
  }
}
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
base::Value Session::GetSpellCheckerLanguages() {
return browser_context_->prefs()
@@ -1841,6 +1861,7 @@ void Session::FillObjectTemplate(v8::Isolate* isolate,
.SetMethod("setCodeCachePath", &Session::SetCodeCachePath)
.SetMethod("clearCodeCaches", &Session::ClearCodeCaches)
.SetMethod("clearData", &Session::ClearData)
.SetMethod("_registerLocalAIHandler", &Session::RegisterLocalAIHandler)
.SetProperty("cookies", &Session::Cookies)
.SetProperty("extensions", &Session::Extensions)
.SetProperty("netLog", &Session::NetLog)

View File

@@ -19,6 +19,7 @@
#include "gin/wrappable.h"
#include "services/network/public/mojom/host_resolver.mojom-forward.h"
#include "services/network/public/mojom/ssl_config.mojom-forward.h"
#include "shell/browser/api/electron_api_utility_process.h"
#include "shell/browser/api/ipc_dispatcher.h"
#include "shell/browser/event_emitter_mixin.h"
#include "shell/browser/net/resolve_proxy_helper.h"
@@ -178,6 +179,8 @@ class Session final : public gin::Wrappable<Session>,
void SetCodeCachePath(gin::Arguments* args);
v8::Local<v8::Promise> ClearCodeCaches(const gin_helper::Dictionary& options);
v8::Local<v8::Value> ClearData(gin::Arguments* args);
void RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
v8::Local<v8::Value> val);
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
base::Value GetSpellCheckerLanguages();
void SetSpellCheckerLanguages(gin_helper::ErrorThrower thrower,

View File

@@ -47,6 +47,10 @@
#include "base/win/windows_types.h"
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace electron {
namespace {
@@ -454,6 +458,19 @@ UtilityProcessWrapper::CreateURLLoaderFactoryParams() {
return params;
}
#if BUILDFLAG(ENABLE_PROMPT_API)
// Forwards an AIManager bind request to the utility process over the
// NodeService pipe, carrying the requesting WebContents id (if any) and the
// requester's security origin.
void UtilityProcessWrapper::BindAIManager(
    std::optional<int32_t> web_contents_id,
    const url::Origin& security_origin,
    mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
  auto params = node::mojom::BindAIManagerParams::New();
  params->web_contents_id = web_contents_id;
  params->security_origin = security_origin;
  node_service_remote_->BindAIManager(std::move(params), std::move(ai_manager));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
// static
raw_ptr<UtilityProcessWrapper> UtilityProcessWrapper::FromProcessId(
base::ProcessId pid) {

View File

@@ -15,6 +15,7 @@
#include "base/memory/weak_ptr.h"
#include "base/process/process_handle.h"
#include "content/public/browser/service_process_host.h"
#include "electron/buildflags/buildflags.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "shell/browser/event_emitter_mixin.h"
@@ -24,6 +25,10 @@
#include "shell/services/node/public/mojom/node_service.mojom.h"
#include "v8/include/v8-forward.h"
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace gin {
class Arguments;
} // namespace gin
@@ -58,6 +63,16 @@ class UtilityProcessWrapper final
static gin_helper::Handle<UtilityProcessWrapper> Create(gin::Arguments* args);
static raw_ptr<UtilityProcessWrapper> FromProcessId(base::ProcessId pid);
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(std::optional<int32_t> web_contents_id,
const url::Origin& security_origin,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager);
#endif // BUILDFLAG(ENABLE_PROMPT_API)
base::WeakPtr<UtilityProcessWrapper> GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
void Shutdown(uint32_t exit_code);
// gin_helper::Wrappable

View File

@@ -230,12 +230,20 @@
#include "ui/webui/resources/cr_components/help_bubble/help_bubble.mojom.h" // nogncheck
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "shell/browser/ai/proxying_ai_manager.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
using content::BrowserThread;
namespace electron {
namespace {
#if BUILDFLAG(ENABLE_PROMPT_API)
const char kAIManagerUserDataKey[] = "ai_manager";
#endif // BUILDFLAG(ENABLE_PROMPT_API)
ElectronBrowserClient* g_browser_client = nullptr;
base::NoDestructor<std::string> g_io_thread_application_locale;
@@ -1580,6 +1588,26 @@ void ElectronBrowserClient::
#endif
}
#if BUILDFLAG(ENABLE_PROMPT_API)
// Refs
// https://source.chromium.org/chromium/chromium/src/+/main:chrome/browser/chrome_content_browser_client.cc;l=8724-8737;drc=74754be9d4550a487df006a51a33318245d37301
// Binds `receiver` to the per-host ProxyingAIManager stored on
// `context_user_data`, creating the manager on first use.
void ElectronBrowserClient::BindAIManager(
    content::BrowserContext* browser_context,
    base::SupportsUserData* context_user_data,
    content::RenderFrameHost* rfh,
    mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
  auto* manager = static_cast<ProxyingAIManager*>(
      context_user_data->GetUserData(kAIManagerUserDataKey));
  if (!manager) {
    auto owned = std::make_unique<ProxyingAIManager>(browser_context, rfh);
    manager = owned.get();
    context_user_data->SetUserData(kAIManagerUserDataKey, std::move(owned));
  }
  manager->AddReceiver(std::move(receiver));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
std::string ElectronBrowserClient::GetApplicationLocale() {
return BrowserThread::CurrentlyOn(BrowserThread::IO)
? *g_io_thread_application_locale

View File

@@ -274,6 +274,14 @@ class ElectronBrowserClient : public content::ContentBrowserClient,
const content::ServiceWorkerVersionBaseInfo& service_worker_version_info,
blink::AssociatedInterfaceRegistry& associated_registry) override;
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(
content::BrowserContext* browser_context,
base::SupportsUserData* context_user_data,
content::RenderFrameHost* rfh,
mojo::PendingReceiver<blink::mojom::AIManager> receiver) override;
#endif // BUILDFLAG(ENABLE_PROMPT_API)
bool HandleExternalProtocol(
const GURL& url,
content::WebContents::Getter web_contents_getter,

View File

@@ -30,6 +30,11 @@ SessionPreferences* SessionPreferences::FromBrowserContext(
return static_cast<SessionPreferences*>(context->GetUserData(&kLocatorKey));
}
// Registers `callback` to run whenever SetLocalAIHandler() installs or
// clears the session's local AI handler. The registration lasts as long as
// the returned subscription is alive.
base::CallbackListSubscription SessionPreferences::AddAIHandlerChangedCallback(
    base::RepeatingClosure callback) {
  return ai_handler_changed_callbacks_.Add(std::move(callback));
}
bool SessionPreferences::HasServiceWorkerPreloadScript() {
const auto& preloads = preload_scripts();
auto it = std::find_if(

View File

@@ -7,8 +7,11 @@
#include <vector>
#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "shell/browser/api/electron_api_utility_process.h"
#include "shell/browser/preload_script.h"
namespace content {
@@ -17,6 +20,10 @@ class BrowserContext;
namespace electron {
namespace api {
class UtilityProcessWrapper;
}
class SessionPreferences : public base::SupportsUserData::Data {
public:
static SessionPreferences* FromBrowserContext(
@@ -30,6 +37,18 @@ class SessionPreferences : public base::SupportsUserData::Data {
bool HasServiceWorkerPreloadScript();
  // Returns the utility process registered as this session's local AI
  // handler; the weak pointer is null when no handler is registered or the
  // wrapper has been destroyed.
  const base::WeakPtr<api::UtilityProcessWrapper>& GetLocalAIHandler() const {
    return local_ai_handler_;
  }
  // Installs (or clears, when `handler` is null) the local AI handler and
  // notifies subscribers registered via AddAIHandlerChangedCallback().
  void SetLocalAIHandler(base::WeakPtr<api::UtilityProcessWrapper> handler) {
    local_ai_handler_ = handler;
    ai_handler_changed_callbacks_.Notify();
  }
base::CallbackListSubscription AddAIHandlerChangedCallback(
base::RepeatingClosure callback);
private:
SessionPreferences();
@@ -37,6 +56,8 @@ class SessionPreferences : public base::SupportsUserData::Data {
static int kLocatorKey;
std::vector<PreloadScript> preload_scripts_;
base::WeakPtr<api::UtilityProcessWrapper> local_ai_handler_;
base::RepeatingClosureList ai_handler_changed_callbacks_;
};
} // namespace electron

View File

@@ -25,6 +25,10 @@ bool IsPrintingEnabled() {
return BUILDFLAG(ENABLE_PRINTING);
}
// True when the build was configured with enable_prompt_api.
bool IsPromptAPIEnabled() {
  return BUILDFLAG(ENABLE_PROMPT_API);
}
// True when the build was configured with enable_electron_extensions.
bool IsExtensionsEnabled() {
  return BUILDFLAG(ENABLE_ELECTRON_EXTENSIONS);
}
@@ -48,6 +52,7 @@ void Initialize(v8::Local<v8::Object> exports,
dict.SetMethod("isFakeLocationProviderEnabled",
&IsFakeLocationProviderEnabled);
dict.SetMethod("isPrintingEnabled", &IsPrintingEnabled);
dict.SetMethod("isPromptAPIEnabled", &IsPromptAPIEnabled);
dict.SetMethod("isComponentBuild", &IsComponentBuild);
dict.SetMethod("isExtensionsEnabled", &IsExtensionsEnabled);
}

View File

@@ -56,6 +56,16 @@ v8::Local<v8::Value> CustomEmit(v8::Isolate* isolate,
converted_args));
}
// Invokes `object[method_name](args...)`, converting the arguments through
// CustomEmit. An EscapableHandleScope is used so the returned handle is
// valid in the caller's handle scope.
template <typename... Args>
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
                                v8::Local<v8::Object> object,
                                const char* method_name,
                                Args&&... args) {
  v8::EscapableHandleScope scope(isolate);
  return scope.Escape(
      CustomEmit(isolate, object, method_name, std::forward<Args>(args)...));
}
template <typename T, typename... Args>
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
gin_helper::DeprecatedWrappable<T>* object,

View File

@@ -116,6 +116,7 @@
V(electron_browser_event_emitter) \
V(electron_browser_system_preferences) \
V(electron_common_net) \
V(electron_utility_local_ai_handler) \
V(electron_utility_parent_port)
#define ELECTRON_TESTING_BINDINGS(V) V(electron_common_testing)

View File

@@ -151,6 +151,22 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
return env;
}
// Constructs a JS `new AbortController()` from the current context's
// global. Crashes via ToLocalChecked() if the lookup or construction
// fails, so callers may assume a valid object.
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate) {
  auto context = isolate->GetCurrentContext();
  auto value =
      context->Global()
          ->Get(context, gin::StringToV8(isolate, "AbortController"))
          .ToLocalChecked();
  // ToLocalChecked() already guarantees a non-empty handle, and every
  // function is an object, so a single IsFunction() check is sufficient
  // (the previous !IsEmpty()/IsObject() DCHECKs were redundant).
  DCHECK(value->IsFunction());
  auto constructor = value.As<v8::Function>();
  return constructor->NewInstance(context, 0, nullptr).ToLocalChecked();
}
ExplicitMicrotasksScope::ExplicitMicrotasksScope(v8::MicrotaskQueue* queue)
: microtask_queue_(queue), original_policy_(queue->microtasks_policy()) {
// In browser-like processes, some nested run loops (macOS usually) may

View File

@@ -66,6 +66,8 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
node::EnvironmentFlags::Flags env_flags,
std::string_view process_type = "");
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate);
// A scope that temporarily changes the microtask policy to explicit. Use this
// anywhere that can trigger Node.js or uv_run().
//

View File

@@ -11,8 +11,10 @@
#include "base/no_destructor.h"
#include "base/process/process.h"
#include "base/strings/utf_string_conversions.h"
#include "electron/buildflags/buildflags.h"
#include "electron/fuses.h"
#include "electron/mas.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "net/base/network_change_notifier.h"
#include "services/network/public/cpp/wrapper_shared_url_loader_factory.h"
#include "services/network/public/mojom/host_resolver.mojom.h"
@@ -30,6 +32,12 @@
#include "shell/common/crash_keys.h"
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "shell/utility/ai/utility_ai_manager.h"
#include "url/gurl.h"
#include "url/origin.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace electron {
mojo::Remote<node::mojom::NodeServiceClient>& GetRemote() {
@@ -215,4 +223,15 @@ void NodeService::UpdateURLLoaderFactory(
params->use_network_observer_from_url_loader_factory);
}
#if BUILDFLAG(ENABLE_PROMPT_API)
// Serves a browser-initiated AIManager bind in the utility process. The
// UtilityAIManager is self-owned: it is destroyed when the `ai_manager`
// pipe disconnects.
void NodeService::BindAIManager(
    node::mojom::BindAIManagerParamsPtr params,
    mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
  mojo::MakeSelfOwnedReceiver(
      std::make_unique<UtilityAIManager>(params->web_contents_id,
                                         params->security_origin),
      std::move(ai_manager));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
} // namespace electron

View File

@@ -7,6 +7,7 @@
#include <memory>
#include "electron/buildflags/buildflags.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
@@ -68,6 +69,12 @@ class NodeService : public node::mojom::NodeService {
void UpdateURLLoaderFactory(
node::mojom::URLLoaderFactoryParamsPtr params) override;
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(
node::mojom::BindAIManagerParamsPtr params,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) override;
#endif // BUILDFLAG(ENABLE_PROMPT_API)
private:
// This needs to be initialized first so that it can be destroyed last
// after the node::Environment is destroyed. This ensures that if

View File

@@ -2,6 +2,7 @@
# Use of this source code is governed by the MIT license that can be
# found in the LICENSE file.
import("//electron/buildflags/buildflags.gni")
import("//mojo/public/tools/bindings/mojom.gni")
mojom("mojom") {
@@ -11,4 +12,8 @@ mojom("mojom") {
"//sandbox/policy/mojom",
"//third_party/blink/public/mojom:mojom_core",
]
if (enable_prompt_api) {
enabled_features = [ "enable_prompt_api" ]
}
}

View File

@@ -8,7 +8,9 @@ import "mojo/public/mojom/base/file_path.mojom";
import "sandbox/policy/mojom/sandbox.mojom";
import "services/network/public/mojom/host_resolver.mojom";
import "services/network/public/mojom/url_loader_factory.mojom";
import "third_party/blink/public/mojom/ai/ai_manager.mojom";
import "third_party/blink/public/mojom/messaging/message_port_descriptor.mojom";
import "url/mojom/origin.mojom";
struct URLLoaderFactoryParams {
pending_remote<network.mojom.URLLoaderFactory> url_loader_factory;
@@ -24,6 +26,11 @@ struct NodeServiceParams {
URLLoaderFactoryParams url_loader_factory_params;
};
// Identifies the requester of an AIManager binding forwarded to the
// utility process.
struct BindAIManagerParams {
  // Id of the requesting WebContents; presumably absent when the request
  // does not originate from a frame (e.g. a worker) — confirm at call site.
  int32? web_contents_id;
  // Security origin of the document or worker requesting the AIManager.
  url.mojom.Origin security_origin;
};
interface NodeServiceClient {
OnV8FatalError(string location, string report);
};
@@ -34,4 +41,8 @@ interface NodeService {
pending_remote<NodeServiceClient> client_remote);
UpdateURLLoaderFactory(URLLoaderFactoryParams params);
[EnableIf=enable_prompt_api]
BindAIManager(BindAIManagerParams params,
pending_receiver<blink.mojom.AIManager> ai_manager);
};

View File

@@ -0,0 +1,726 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/ai/utility_ai_language_model.h"
#include <string_view>
#include "base/no_destructor.h"
#include "base/notimplemented.h"
#include "shell/browser/javascript_environment.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/std_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/gin_helper/event_emitter_caller.h"
#include "shell/common/node_includes.h"
#include "shell/common/node_util.h"
#include "shell/common/v8_util.h"
#include "shell/utility/ai/utility_ai_manager.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
namespace gin {
template <>
struct Converter<on_device_model::mojom::ResponseConstraintPtr> {
  // Converts a response constraint to JS: a JSON-schema constraint is
  // parsed into a plain object, a regex constraint becomes a RegExp with
  // no flags. Null or unrecognized variants convert to `undefined`.
  // JSON::Parse is ToLocalChecked(), so a malformed schema string crashes.
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      const on_device_model::mojom::ResponseConstraintPtr& val) {
    if (val.is_null())
      return v8::Undefined(isolate);
    if (val->is_json_schema()) {
      return v8::JSON::Parse(isolate->GetCurrentContext(),
                             StringToV8(isolate, val->get_json_schema()))
          .ToLocalChecked();
    } else if (val->is_regex()) {
      return v8::RegExp::New(isolate->GetCurrentContext(),
                             StringToV8(isolate, val->get_regex()),
                             v8::RegExp::kNone)
          .ToLocalChecked();
    }
    return v8::Undefined(isolate);
  }
};
template <>
struct Converter<blink::mojom::AILanguageModelPromptRole> {
  // Maps the mojo role enum to the Prompt API's string role names, with
  // "unknown" as a defensive fallback for unrecognized values.
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      blink::mojom::AILanguageModelPromptRole value) {
    switch (value) {
      case blink::mojom::AILanguageModelPromptRole::kSystem:
        return StringToV8(isolate, "system");
      case blink::mojom::AILanguageModelPromptRole::kUser:
        return StringToV8(isolate, "user");
      case blink::mojom::AILanguageModelPromptRole::kAssistant:
        return StringToV8(isolate, "assistant");
      default:
        return StringToV8(isolate, "unknown");
    }
  }
};
template <>
struct Converter<blink::mojom::AILanguageModelPromptContentPtr> {
  // Converts one prompt-content part into a `{ type, value }` JS object:
  //   text  -> { type: 'text',  value: string }
  //   image -> { type: 'image', value: ArrayBuffer of N32-premul pixels }
  //   audio -> { type: 'audio', value: ArrayBuffer of float32 samples }
  // A null pointer converts to `undefined`.
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      const blink::mojom::AILanguageModelPromptContentPtr& val) {
    if (val.is_null())
      return v8::Undefined(isolate);
    auto dict = gin::Dictionary::CreateEmpty(isolate);
    if (val->is_text()) {
      dict.Set("type", "text");
      dict.Set("value", val->get_text());
    } else if (val->is_bitmap()) {
      // Convert the bitmap to an ArrayBuffer
      // TODO - Are we going to make any guarantees about the shape of the image
      // data?
      SkBitmap& bitmap = val->get_bitmap();
      const auto dst_info = SkImageInfo::MakeN32Premul(bitmap.dimensions());
      const size_t dst_n_bytes = dst_info.computeMinByteSize();
      auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
      // On failure the buffer's pixel contents are unspecified; the
      // uncaught exception surfaces the problem to the handler's process.
      if (!bitmap.readPixels(dst_info, dst_buf->Data(), dst_info.minRowBytes(),
                             0, 0)) {
        auto err = v8::Exception::TypeError(
            gin::StringToV8(isolate, "Invalid bitmap content in prompt"));
        node::errors::TriggerUncaughtException(isolate, err, {});
      }
      dict.Set("type", "image");
      dict.Set("value", dst_buf);
    } else if (val->is_audio()) {
      // Convert the audio data to an ArrayBuffer
      // TODO - Are we going to make any guarantees about the shape of the audio
      // data?
      on_device_model::mojom::AudioDataPtr& audio_data = val->get_audio();
      std::vector<float>& raw_data = audio_data->data;
      const size_t dst_n_bytes =
          sizeof(std::remove_reference_t<decltype(raw_data)>::value_type) *
          raw_data.size();
      auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
      // BUG FIX: the destination must be traversed as float*. Copying
      // through a char* output iterator narrowed every float sample to a
      // single byte and filled only the first raw_data.size() bytes of the
      // sizeof(float) * raw_data.size() buffer.
      UNSAFE_BUFFERS(
          std::ranges::copy(raw_data, static_cast<float*>(dst_buf->Data())));
      dict.Set("type", "audio");
      dict.Set("value", dst_buf);
    }
    return ConvertToV8(isolate, dict);
  }
};
// Converts a whole prompt message into a `{ role, content, prefix }` JS
// object (each field converted by the converters above). A null pointer
// converts to `undefined`.
v8::Local<v8::Value> Converter<blink::mojom::AILanguageModelPromptPtr>::ToV8(
    v8::Isolate* isolate,
    const blink::mojom::AILanguageModelPromptPtr& val) {
  if (val.is_null())
    return v8::Undefined(isolate);
  auto dict = gin::Dictionary::CreateEmpty(isolate);
  dict.Set("role", val->role);
  dict.Set("content", val->content);
  dict.Set("prefix", val->is_prefix);
  return ConvertToV8(isolate, dict);
}
} // namespace gin
namespace electron {
namespace {
// Keys under which the JS side stashes predicate functions as private
// properties on the global object.
constexpr std::string_view kIsReadableStreamKey = "isReadableStream";
constexpr std::string_view kIsLanguageModelKey = "isLanguageModel";
constexpr std::string_view kIsLanguageModelClassKey = "isLanguageModelClass";

// Fetches the predicate function stored under the private symbol `key` on
// the current context's global object. Despite the name, the return value
// is the predicate *function* (a boolean-returning JS function), not a
// boolean. LOG(FATAL)s if the binding was never installed.
v8::Local<v8::Function> GetPrivateBoolean(v8::Isolate* const isolate,
                                          const v8::Local<v8::Context>& context,
                                          std::string_view key) {
  auto binding_key = gin::StringToV8(isolate, key);
  auto private_binding_key = v8::Private::ForApi(isolate, binding_key);
  auto global_object = context->Global();
  auto value =
      global_object->GetPrivate(context, private_binding_key).ToLocalChecked();
  if (value.IsEmpty() || !value->IsFunction()) {
    LOG(FATAL) << "Attempted to get the '" << key
               << "' value but it was missing";
  }
  return value.As<v8::Function>();
}
// Returns true if `val` is a ReadableStream, as judged by the JS-side
// predicate installed under kIsReadableStreamKey. The predicate function
// is cached for the process lifetime in a NoDestructor global.
// NOTE(review): the cache assumes a single isolate/context supplies the
// binding — confirm if multiple contexts are possible here.
bool IsReadableStream(v8::Isolate* isolate, v8::Local<v8::Value> val) {
  static base::NoDestructor<v8::Global<v8::Function>> is_readable_stream;
  auto context = isolate->GetCurrentContext();
  if (is_readable_stream.get()->IsEmpty()) {
    is_readable_stream->Reset(
        isolate, GetPrivateBoolean(isolate, context, kIsReadableStreamKey));
  }
  v8::Local<v8::Value> args[] = {val};
  v8::Local<v8::Value> result =
      is_readable_stream->Get(isolate)
          ->Call(context, v8::Null(isolate), std::size(args), args)
          .ToLocalChecked();
  return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
// Reads the `contextUsage` property (token count) off the JS language
// model object. Returns 0 when the property is missing or not a number.
uint64_t GetContextUsage(v8::Isolate* isolate,
                         v8::Local<v8::Object> language_model) {
  auto context = isolate->GetCurrentContext();
  v8::Local<v8::Value> val =
      language_model->Get(context, gin::StringToV8(isolate, "contextUsage"))
          .ToLocalChecked();
  uint64_t token_count = 0;
  if (val->IsNumber()) {
    gin::ConvertFromV8(isolate, val, &token_count);
  }
  return token_count;
}
// Owns itself. Will live as long as there's more data to process
// and the Mojo remote is still connected.
//
// Bridges the JS return value of LanguageModel.prompt() — a string, a
// promise, or a ReadableStream — to a renderer-side
// blink::mojom::ModelStreamingResponder. Every exit path funnels through
// DeleteThis(), which also aborts the handler-side AbortController when
// the response never reached completion.
class PromptResponder {
 public:
  PromptResponder(v8::Isolate* isolate,
                  v8::Local<v8::Value> value,
                  v8::Local<v8::Object> abort_controller,
                  v8::Local<v8::Object> language_model,
                  mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                      pending_responder,
                  UtilityAILanguageModel* model) {
    abort_controller_.Reset(isolate, abort_controller);
    language_model_.Reset(isolate, language_model);
    responder_.Bind(std::move(pending_responder));
    // If the renderer drops the responder pipe, tear ourselves down
    // (DeleteThis aborts the in-progress JS work first).
    responder_.set_disconnect_handler(
        base::BindOnce(&PromptResponder::DeleteThis, base::Unretained(this)));
    // If the owning session is destroyed mid-response, report the error
    // to the renderer and self-destruct.
    destroy_subscription_ = model->AddDestroyObserver(base::BindRepeating(
        &PromptResponder::OnModelDestroyed, base::Unretained(this)));
    Respond(isolate, value);
  }

  // disable copy
  PromptResponder(const PromptResponder&) = delete;
  PromptResponder& operator=(const PromptResponder&) = delete;

 private:
  // Invoked when the UtilityAILanguageModel that spawned us is destroyed.
  void OnModelDestroyed() {
    // Drop the subscription since the model is already being destroyed.
    destroy_subscription_ = {};
    responder_->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
        /*quota_error_info=*/nullptr);
    DeleteThis();
  }

  // Entry point: if `value` is a promise, resume in RespondImplementation
  // once it settles (rejection reports an error and self-deletes);
  // otherwise handle it immediately.
  void Respond(v8::Isolate* isolate, v8::Local<v8::Value> value) {
    if (value->IsPromise()) {
      auto promise = value.As<v8::Promise>();
      auto then_cb = base::BindOnce(
          [](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
             v8::Local<v8::Value> result) {
            if (weak_ptr) {
              weak_ptr->RespondImplementation(isolate, result);
            }
          },
          weak_ptr_factory_.GetWeakPtr(), isolate);
      auto catch_cb = base::BindOnce(
          [](base::WeakPtr<PromptResponder> weak_ptr,
             v8::Local<v8::Value> result) {
            if (weak_ptr) {
              weak_ptr->SendError();
              weak_ptr->DeleteThis();
            }
          },
          weak_ptr_factory_.GetWeakPtr());
      std::ignore = promise->Then(
          isolate->GetCurrentContext(),
          gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
          gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
    } else {
      RespondImplementation(isolate, value);
    }
  }

  // Handles a settled value: a string is streamed in one chunk and the
  // response completed; a ReadableStream enters the Read() pump; anything
  // else is an error, reported both to the renderer and as an uncaught
  // exception in the handler process.
  void RespondImplementation(v8::Isolate* isolate, v8::Local<v8::Value> val) {
    std::string response;
    if (val->IsString() && gin::ConvertFromV8(isolate, val, &response)) {
      responder_->OnStreaming(response);
      uint64_t token_count =
          GetContextUsage(isolate, language_model_.Get(isolate));
      responder_->OnCompletion(
          blink::mojom::ModelExecutionContextInfo::New(token_count));
      completed_ = true;
      DeleteThis();
    } else if (IsReadableStream(isolate, val)) {
      v8::Local<v8::Value> reader =
          gin_helper::CallMethod(isolate, val.As<v8::Object>(), "getReader");
      DCHECK(reader->IsObject());
      readable_stream_reader_.Reset(isolate, reader.As<v8::Object>());
      Read(isolate);
    } else {
      SendError();
      // DeleteThis() only touches members, so raising the uncaught
      // exception afterwards (using locals only) is safe.
      DeleteThis();
      auto err = v8::Exception::TypeError(gin::StringToV8(
          isolate, "Invalid return value from LanguageModel.prompt()"));
      node::errors::TriggerUncaughtException(isolate, err, {});
    }
  }

  // One step of the stream pump: reader.read() yields {done, value}; each
  // string chunk is forwarded via OnStreaming and the pump re-armed until
  // `done`, at which point the response is completed. Any failure (read
  // rejection, non-string chunk) reports an error and self-deletes.
  void Read(v8::Isolate* isolate) {
    v8::Local<v8::Value> val = gin_helper::CallMethod(
        isolate, readable_stream_reader_.Get(isolate), "read");
    DCHECK(val->IsPromise());
    auto promise = val.As<v8::Promise>();
    auto then_cb = base::BindOnce(
        [](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
           v8::Local<v8::Value> result) {
          if (weak_ptr) {
            CHECK(result->IsObject());
            v8::Local<v8::Value> done =
                result.As<v8::Object>()
                    ->Get(isolate->GetCurrentContext(),
                          gin::StringToV8(isolate, "done"))
                    .ToLocalChecked();
            CHECK(done->IsBoolean());
            if (done.As<v8::Boolean>()->Value()) {
              uint64_t token_count = GetContextUsage(
                  isolate, weak_ptr->language_model_.Get(isolate));
              weak_ptr->responder_->OnCompletion(
                  blink::mojom::ModelExecutionContextInfo::New(token_count));
              weak_ptr->completed_ = true;
              weak_ptr->DeleteThis();
            } else {
              v8::Local<v8::Value> val =
                  result.As<v8::Object>()
                      ->Get(isolate->GetCurrentContext(),
                            gin::StringToV8(isolate, "value"))
                      .ToLocalChecked();
              DCHECK(val->IsString());
              std::string value;
              if (gin::ConvertFromV8(isolate, val, &value)) {
                weak_ptr->responder_->OnStreaming(value);
                weak_ptr->Read(isolate);
              } else {
                weak_ptr->SendError();
                weak_ptr->DeleteThis();
              }
            }
          }
        },
        weak_ptr_factory_.GetWeakPtr(), isolate);
    auto catch_cb = base::BindOnce(
        [](base::WeakPtr<PromptResponder> weak_ptr,
           v8::Local<v8::Value> result) {
          if (weak_ptr) {
            weak_ptr->SendError();
            weak_ptr->DeleteThis();
          }
        },
        weak_ptr_factory_.GetWeakPtr());
    std::ignore = promise->Then(
        isolate->GetCurrentContext(),
        gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
        gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
  }

  // Reports a generic failure to the renderer.
  void SendError() {
    responder_->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
        /*quota_error_info=*/nullptr);
  }

  // Sole destruction path. Invalidates pending promise callbacks, and if
  // the response never completed, cancels the stream reader (if any) and
  // aborts the handler-side AbortController before deleting `this`.
  void DeleteThis() {
    destroy_subscription_ = {};
    weak_ptr_factory_.InvalidateWeakPtrs();
    if (!completed_) {
      v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
      v8::HandleScope scope{isolate};
      if (!readable_stream_reader_.IsEmpty()) {
        gin_helper::CallMethod(isolate, readable_stream_reader_.Get(isolate),
                               "cancel");
      }
      gin_helper::CallMethod(isolate, abort_controller_.Get(isolate), "abort");
    }
    delete this;
  }

  // Set once OnCompletion has been sent; suppresses the abort in
  // DeleteThis().
  bool completed_ = false;
  v8::Global<v8::Object> readable_stream_reader_;
  v8::Global<v8::Object> abort_controller_;
  v8::Global<v8::Object> language_model_;
  mojo::Remote<blink::mojom::ModelStreamingResponder> responder_;
  base::CallbackListSubscription destroy_subscription_;
  base::WeakPtrFactory<PromptResponder> weak_ptr_factory_{this};
};
} // namespace
// Wraps the JS `language_model` object so renderer-side mojo calls can be
// forwarded to it. `manager` is held weakly.
UtilityAILanguageModel::UtilityAILanguageModel(
    v8::Local<v8::Object> language_model,
    base::WeakPtr<UtilityAIManager> manager)
    : manager_(std::move(manager)) {
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  language_model_.Reset(isolate, language_model);
  // Abort the JS work tied to a responder when its renderer side goes away.
  responder_set_.set_disconnect_handler(
      base::BindRepeating(&UtilityAILanguageModel::OnResponderDisconnect,
                          weak_ptr_factory_.GetWeakPtr()));
}
// Runs Destroy() (defined elsewhere in this file) if it has not already
// happened, so destruction side effects fire even without an explicit call.
UtilityAILanguageModel::~UtilityAILanguageModel() {
  if (!is_destroyed_) {
    Destroy();
  }
}
// Registers `callback` to run when this model is destroyed (used by
// PromptResponder to fail in-flight responses). Lives as long as the
// returned subscription.
base::CallbackListSubscription UtilityAILanguageModel::AddDestroyObserver(
    base::RepeatingClosure callback) {
  return on_destroy_.Add(std::move(callback));
}
// When a renderer-side streaming responder disconnects, abort the
// AbortController associated with that request (if one was registered) so
// the JS handler can stop its work, then drop the bookkeeping entry.
void UtilityAILanguageModel::OnResponderDisconnect(
    mojo::RemoteSetElementId id) {
  auto it = abort_controllers_.find(id);
  if (it != abort_controllers_.end()) {
    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
    v8::HandleScope scope{isolate};
    gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort");
    abort_controllers_.erase(it);
  }
}
// Returns the streaming responder proxy for `responder_id`, or null when
// the id is no longer present in the set.
blink::mojom::ModelStreamingResponder* UtilityAILanguageModel::GetResponder(
    mojo::RemoteSetElementId responder_id) {
  return responder_set_.Get(responder_id);
}
// static
// Returns true if `val` is a LanguageModel instance, as judged by the
// JS-side predicate installed under kIsLanguageModelKey; the predicate
// function is cached for the process lifetime.
bool UtilityAILanguageModel::IsLanguageModel(v8::Isolate* isolate,
                                             v8::Local<v8::Value> val) {
  static base::NoDestructor<v8::Global<v8::Function>> is_language_model;
  auto context = isolate->GetCurrentContext();
  if (is_language_model.get()->IsEmpty()) {
    is_language_model->Reset(
        isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelKey));
  }
  v8::Local<v8::Value> args[] = {val};
  v8::Local<v8::Value> result =
      is_language_model->Get(isolate)
          ->Call(context, v8::Null(isolate), std::size(args), args)
          .ToLocalChecked();
  return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
// static
// Returns true if `val` is the LanguageModel class itself (not an
// instance), as judged by the JS-side predicate installed under
// kIsLanguageModelClassKey; the predicate function is cached for the
// process lifetime.
bool UtilityAILanguageModel::IsLanguageModelClass(v8::Isolate* isolate,
                                                 v8::Local<v8::Value> val) {
  static base::NoDestructor<v8::Global<v8::Function>> is_language_model_class;
  auto context = isolate->GetCurrentContext();
  if (is_language_model_class.get()->IsEmpty()) {
    is_language_model_class->Reset(
        isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelClassKey));
  }
  v8::Local<v8::Value> args[] = {val};
  v8::Local<v8::Value> result =
      is_language_model_class->Get(isolate)
          ->Call(context, v8::Null(isolate), std::size(args), args)
          .ToLocalChecked();
  return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
// Forwards a Prompt request to the JS object's `prompt()` method.
// An AbortController's `signal` (plus any response constraint) is passed in
// the options dictionary; streaming of the JS result back over mojo is
// delegated to a PromptResponder.
void UtilityAILanguageModel::Prompt(
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
    on_device_model::mojom::ResponseConstraintPtr constraint,
    mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
        pending_responder) {
  // A destroyed session rejects immediately with kErrorSessionDestroyed.
  if (is_destroyed_) {
    mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
        std::move(pending_responder));
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
        /*quota_error_info=*/nullptr);
    return;
  }
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
  auto options = gin_helper::Dictionary::CreateEmpty(isolate);
  if (!constraint.is_null()) {
    options.Set("responseConstraint", gin::ConvertToV8(isolate, constraint));
  }
  options.Set("signal", abort_controller
                            ->Get(isolate->GetCurrentContext(),
                                  gin::StringToV8(isolate, "signal"))
                            .ToLocalChecked());
  v8::Local<v8::Value> val = gin_helper::CallMethod(
      isolate, language_model_.Get(isolate), "prompt", prompts, options);
  // NOTE(review): the raw `new` is intentional only if PromptResponder
  // manages its own lifetime (self-deletes on completion/disconnect) —
  // confirm against the PromptResponder definition earlier in this file.
  new PromptResponder(isolate, val, abort_controller,
                      language_model_.Get(isolate),
                      std::move(pending_responder), this);
}
// Forwards an Append request to the JS object's `append()` method. Unlike
// Prompt(), no tokens are streamed back: on success the responder only
// receives OnCompletion() carrying the updated context usage; on rejection
// it receives a generic kErrorUnknown.
void UtilityAILanguageModel::Append(
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
    mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
        pending_responder) {
  // A destroyed session rejects immediately with kErrorSessionDestroyed.
  if (is_destroyed_) {
    mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
        std::move(pending_responder));
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
        /*quota_error_info=*/nullptr);
    return;
  }
  mojo::RemoteSetElementId responder_id =
      responder_set_.Add(std::move(pending_responder));
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  // Track the AbortController keyed by responder id so a mojo disconnect
  // (OnResponderDisconnect) can abort the JS-side call.
  v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
  abort_controllers_.emplace(responder_id,
                             v8::Global<v8::Object>(isolate, abort_controller));
  auto options = gin_helper::Dictionary::CreateEmpty(isolate);
  options.Set("signal", abort_controller
                            ->Get(isolate->GetCurrentContext(),
                                  gin::StringToV8(isolate, "signal"))
                            .ToLocalChecked());
  v8::Local<v8::Value> val = gin_helper::CallMethod(
      isolate, language_model_.Get(isolate), "append", prompts, options);
  // Success path: drop the abort handle, then report completion with the
  // model's current context usage (read back from the JS object).
  auto SendResponse =
      [](base::WeakPtr<UtilityAILanguageModel> weak_ptr, v8::Isolate* isolate,
         mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
        if (!weak_ptr)
          return;
        weak_ptr->abort_controllers_.erase(responder_id);
        blink::mojom::ModelStreamingResponder* responder =
            weak_ptr->GetResponder(responder_id);
        if (!responder) {
          return;
        }
        uint64_t token_count =
            GetContextUsage(isolate, weak_ptr->language_model_.Get(isolate));
        responder->OnCompletion(
            blink::mojom::ModelExecutionContextInfo::New(token_count));
      };
  if (val->IsPromise()) {
    auto promise = val.As<v8::Promise>();
    auto then_cb = base::BindOnce(SendResponse, weak_ptr_factory_.GetWeakPtr(),
                                  isolate, responder_id);
    // Rejection path: same cleanup, but surface kErrorUnknown.
    auto catch_cb = base::BindOnce(
        [](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
           mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
          if (!weak_ptr)
            return;
          weak_ptr->abort_controllers_.erase(responder_id);
          blink::mojom::ModelStreamingResponder* responder =
              weak_ptr->GetResponder(responder_id);
          if (!responder) {
            return;
          }
          responder->OnError(
              blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
              /*quota_error_info=*/nullptr);
        },
        weak_ptr_factory_.GetWeakPtr(), responder_id);
    std::ignore = promise->Then(
        isolate->GetCurrentContext(),
        gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
        gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
  } else {
    // The method is supposed to return a promise, but for
    // convenience allow developers to return a value directly
    SendResponse(weak_ptr_factory_.GetWeakPtr(), isolate, responder_id, val);
  }
}
// Clones this session by invoking `clone()` on the wrapped JS object via the
// manager's shared create/clone helper. Fails fast with
// kUnableToCreateSession when this session is destroyed or the manager is
// already gone.
void UtilityAILanguageModel::Fork(
    mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
        client) {
  if (is_destroyed_ || !manager_) {
    mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
        client_remote(std::move(client));
    client_remote->OnError(
        blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
        /*quota_error_info=*/nullptr);
    return;
  }
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  // Empty options dictionary and default create-options: the clone targets
  // the live JS model object and calls its "clone" method.
  manager_->CreateLanguageModelInternal(
      isolate, std::move(client), language_model_.Get(isolate), "clone",
      gin_helper::Dictionary::CreateEmpty(isolate),
      blink::mojom::AILanguageModelCreateOptions::New());
}
// Tears down the session: notifies destroy observers, errors out all pending
// streaming responders, aborts every outstanding JS-side operation, and
// finally calls `destroy()` on the wrapped JS object. Idempotent.
void UtilityAILanguageModel::Destroy() {
  if (is_destroyed_) {
    return;
  }
  is_destroyed_ = true;
  // Notify observers (e.g. in-progress PromptResponders) before
  // tearing down the responder set and abort controllers.
  on_destroy_.Notify();
  for (auto& responder : responder_set_) {
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
        /*quota_error_info=*/nullptr);
  }
  responder_set_.Clear();
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  // Abort any outstanding Prompt/Append work...
  for (auto& [id, controller] : abort_controllers_) {
    gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
  }
  abort_controllers_.clear();
  // ...and any in-flight measureContextUsage() calls.
  for (auto& controller : measure_abort_controllers_) {
    gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
  }
  measure_abort_controllers_.clear();
  gin_helper::CallMethod(isolate, language_model_.Get(isolate), "destroy");
}
// Asks the JS object's `measureContextUsage()` how many tokens |input| would
// consume. Runs |callback| with the count on success, or std::nullopt on
// rejection / null result; a non-numeric, non-null result additionally
// reports an uncaught TypeError.
void UtilityAILanguageModel::MeasureInputUsage(
    std::vector<blink::mojom::AILanguageModelPromptPtr> input,
    MeasureInputUsageCallback callback) {
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  // Track the AbortController in a std::list so Destroy() can abort it;
  // list iterators stay valid across other insertions/erasures, so the
  // callbacks below can capture |abort_it| safely.
  v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
  measure_abort_controllers_.emplace_back(isolate, abort_controller);
  auto abort_it = std::prev(measure_abort_controllers_.end());
  auto options = gin_helper::Dictionary::CreateEmpty(isolate);
  options.Set("signal", abort_controller
                            ->Get(isolate->GetCurrentContext(),
                                  gin::StringToV8(isolate, "signal"))
                            .ToLocalChecked());
  v8::Local<v8::Value> val =
      gin_helper::CallMethod(isolate, language_model_.Get(isolate),
                             "measureContextUsage", input, options);
  // Shared completion: unregister the abort controller, then map the JS
  // result to the mojo reply (number -> count, null -> nullopt,
  // anything else -> nullopt + uncaught TypeError).
  auto RunCallback = [](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
                        std::list<v8::Global<v8::Object>>::iterator abort_it,
                        v8::Isolate* isolate,
                        MeasureInputUsageCallback callback,
                        v8::Local<v8::Value> result) {
    if (weak_ptr) {
      weak_ptr->measure_abort_controllers_.erase(abort_it);
    }
    uint32_t input_tokens = 0;
    if (result->IsNumber() &&
        gin::ConvertFromV8(isolate, result, &input_tokens)) {
      std::move(callback).Run(std::move(input_tokens));
    } else if (result->IsNull()) {
      std::move(callback).Run(std::nullopt);
    } else {
      std::move(callback).Run(std::nullopt);
      auto err = v8::Exception::TypeError(gin::StringToV8(
          isolate,
          "Invalid return value from LanguageModel.measureContextUsage()"));
      node::errors::TriggerUncaughtException(isolate, err, {});
    }
  };
  if (val->IsPromise()) {
    auto promise = val.As<v8::Promise>();
    // The mojo callback is single-shot but either the then- or catch-branch
    // may fire; SplitOnceCallback arbitrates between them.
    auto split_callback = base::SplitOnceCallback(std::move(callback));
    auto then_cb =
        base::BindOnce(RunCallback, weak_ptr_factory_.GetWeakPtr(), abort_it,
                       isolate, std::move(split_callback.first));
    auto catch_cb = base::BindOnce(
        [](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
           std::list<v8::Global<v8::Object>>::iterator abort_it,
           MeasureInputUsageCallback callback, v8::Local<v8::Value> result) {
          if (weak_ptr) {
            weak_ptr->measure_abort_controllers_.erase(abort_it);
          }
          std::move(callback).Run(std::nullopt);
        },
        weak_ptr_factory_.GetWeakPtr(), abort_it,
        std::move(split_callback.second));
    std::ignore = promise->Then(
        isolate->GetCurrentContext(),
        gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
        gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
  } else {
    // The method is supposed to return a promise, but for
    // convenience allow developers to return a value directly
    RunCallback(weak_ptr_factory_.GetWeakPtr(), abort_it, isolate,
                std::move(callback), val);
  }
}
} // namespace electron

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
#include <list>
#include <vector>
#include "base/callback_list.h"
#include "base/memory/weak_ptr.h"
#include "gin/converter.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "v8/include/v8.h"
namespace electron {
class UtilityAIManager;
// Utility-process implementation of `blink::mojom::AILanguageModel` that
// delegates every operation (prompt/append/clone/destroy/token counting) to
// an app-provided JS language-model object.
class UtilityAILanguageModel : public blink::mojom::AILanguageModel {
 public:
  UtilityAILanguageModel(v8::Local<v8::Object> language_model,
                         base::WeakPtr<UtilityAIManager> manager);
  UtilityAILanguageModel(const UtilityAILanguageModel&) = delete;
  UtilityAILanguageModel& operator=(const UtilityAILanguageModel&) = delete;
  ~UtilityAILanguageModel() override;
  // Subscribe to be notified when this model is destroyed. The returned
  // subscription auto-unregisters when destroyed.
  [[nodiscard]] base::CallbackListSubscription AddDestroyObserver(
      base::RepeatingClosure callback);
  // True iff |val| is a LanguageModel instance (resp. the LanguageModel
  // class object), as determined by JS-side predicates.
  static bool IsLanguageModel(v8::Isolate* isolate, v8::Local<v8::Value> val);
  static bool IsLanguageModelClass(v8::Isolate* isolate,
                                   v8::Local<v8::Value> val);
  // `blink::mojom::AILanguageModel` implementation.
  void Prompt(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              on_device_model::mojom::ResponseConstraintPtr constraint,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Append(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
              mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
                  pending_responder) override;
  void Fork(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client) override;
  void Destroy() override;
  void MeasureInputUsage(
      std::vector<blink::mojom::AILanguageModelPromptPtr> input,
      MeasureInputUsageCallback callback) override;

 private:
  // Aborts the JS-side operation tied to a disconnected responder.
  void OnResponderDisconnect(mojo::RemoteSetElementId id);
  // Returns the live responder for |responder_id|, or nullptr if it is no
  // longer in |responder_set_|.
  blink::mojom::ModelStreamingResponder* GetResponder(
      mojo::RemoteSetElementId responder_id);
  // Weak: the manager may be destroyed before outstanding sessions.
  base::WeakPtr<UtilityAIManager> manager_;
  // Persistent handle to the app-provided JS language-model object.
  v8::Global<v8::Object> language_model_;
  bool is_destroyed_ = false;
  mojo::RemoteSet<blink::mojom::ModelStreamingResponder> responder_set_;
  // Maps each in-progress Prompt/Append responder to its AbortController
  // so we can abort the JS-side operation if the responder disconnects.
  absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
      abort_controllers_;
  // Tracks abort controllers for in-progress MeasureInputUsage calls.
  std::list<v8::Global<v8::Object>> measure_abort_controllers_;
  // Notified when this model is destroyed, allowing in-progress
  // PromptResponder instances to clean up.
  base::RepeatingClosureList on_destroy_;
  base::WeakPtrFactory<UtilityAILanguageModel> weak_ptr_factory_{this};
};
} // namespace electron
namespace gin {
// Serializes a mojo language-model prompt into a V8 value so prompts can be
// forwarded to the JS implementation. Defined in the corresponding .cc file.
template <>
struct Converter<blink::mojom::AILanguageModelPromptPtr> {
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      const blink::mojom::AILanguageModelPromptPtr& val);
};
} // namespace gin
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_

View File

@@ -0,0 +1,523 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/ai/utility_ai_manager.h"
#include <optional>
#include <utility>
#include "base/containers/fixed_flat_map.h"
#include "base/notimplemented.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "shell/browser/javascript_environment.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/std_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/gin_helper/event_emitter_caller.h"
#include "shell/common/node_includes.h"
#include "shell/common/node_util.h"
#include "shell/utility/ai/utility_ai_language_model.h"
#include "shell/utility/api/electron_api_local_ai_handler.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
#include "url/gurl.h"
#include "url/origin.h"
#include "v8/include/v8.h"
namespace gin {
// Maps the availability strings returned by the JS `availability()` hook to
// the blink mojo enum. Any unrecognized value fails conversion (returns
// false), which callers treat as an invalid return value.
template <>
struct Converter<blink::mojom::ModelAvailabilityCheckResult> {
  static bool FromV8(v8::Isolate* isolate,
                     v8::Local<v8::Value> val,
                     blink::mojom::ModelAvailabilityCheckResult* out) {
    using Result = blink::mojom::ModelAvailabilityCheckResult;
    static constexpr auto Lookup =
        base::MakeFixedFlatMap<std::string_view, Result>({
            {"available", Result::kAvailable},
            {"unavailable", Result::kUnavailableUnknown},
            {"downloading", Result::kDownloading},
            {"downloadable", Result::kDownloadable},
        });
    return FromV8WithLookup(isolate, val, Lookup, out);
  }
};
// Serializes a prompt-content type to its Web-facing string name; unknown
// enum values become the literal string "unknown".
template <>
struct Converter<blink::mojom::AILanguageModelPromptType> {
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      blink::mojom::AILanguageModelPromptType value) {
    switch (value) {
      case blink::mojom::AILanguageModelPromptType::kText:
        return StringToV8(isolate, "text");
      case blink::mojom::AILanguageModelPromptType::kImage:
        return StringToV8(isolate, "image");
      case blink::mojom::AILanguageModelPromptType::kAudio:
        return StringToV8(isolate, "audio");
      default:
        return StringToV8(isolate, "unknown");
    }
  }
};
// Serializes a language code to its bare string form (undefined when null).
template <>
struct Converter<blink::mojom::AILanguageCodePtr> {
  static v8::Local<v8::Value> ToV8(v8::Isolate* isolate,
                                   const blink::mojom::AILanguageCodePtr& val) {
    if (val.is_null()) {
      return v8::Undefined(isolate);
    }
    return StringToV8(isolate, val->code);
  }
};

// Serializes an expected-input/output descriptor to a `{ type, languages? }`
// dictionary; `languages` is omitted when absent or empty.
template <>
struct Converter<blink::mojom::AILanguageModelExpectedPtr> {
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      const blink::mojom::AILanguageModelExpectedPtr& val) {
    if (val.is_null()) {
      return v8::Undefined(isolate);
    }
    auto dict = gin::Dictionary::CreateEmpty(isolate);
    dict.Set("type", val->type);
    if (val->languages.has_value() && !val->languages->empty()) {
      dict.Set("languages", val->languages.value());
    }
    return ConvertToV8(isolate, dict);
  }
};
// Serializes create() options into the dictionary handed to the JS
// LanguageModel class. Returns undefined when there is nothing to convey.
// NOTE(review): sampling_params participates in the "is empty" check but is
// never written into the dictionary below — confirm whether samplingParams
// serialization was intentionally deferred.
template <>
struct Converter<blink::mojom::AILanguageModelCreateOptionsPtr> {
  static v8::Local<v8::Value> ToV8(
      v8::Isolate* isolate,
      const blink::mojom::AILanguageModelCreateOptionsPtr& val) {
    if (val.is_null() ||
        (val->sampling_params.is_null() && !val->expected_inputs.has_value() &&
         !val->expected_outputs.has_value() && val->initial_prompts.empty())) {
      return v8::Undefined(isolate);
    }
    auto dict = gin::Dictionary::CreateEmpty(isolate);
    if (val->expected_inputs.has_value() && !val->expected_inputs->empty()) {
      dict.Set("expectedInputs", val->expected_inputs.value());
    }
    if (val->expected_outputs.has_value() && !val->expected_outputs->empty()) {
      dict.Set("expectedOutputs", val->expected_outputs.value());
    }
    if (!val->initial_prompts.empty()) {
      dict.Set("initialPrompts", val->initial_prompts);
    }
    return ConvertToV8(isolate, dict);
  }
};
} // namespace gin
namespace electron {
// |web_contents_id| identifies the requesting WebContents (absent when the
// request has none); |security_origin| is the requesting origin. Both are
// forwarded verbatim to the JS handler in GetLanguageModelClass().
UtilityAIManager::UtilityAIManager(std::optional<int32_t> web_contents_id,
                                   const url::Origin& security_origin)
    : web_contents_id_(web_contents_id), security_origin_(security_origin) {
  // Abort the matching JS create()/clone() call when a pending client's
  // mojo pipe disconnects.
  create_model_client_set_.set_disconnect_with_reason_handler(
      base::BindRepeating(
          &UtilityAIManager::OnCreateLanguageModelClientDisconnect,
          weak_ptr_factory_.GetWeakPtr()));
}

UtilityAIManager::~UtilityAIManager() {
  // Trigger the abort signal for any in-progress CreateLanguageModel calls
  for (auto it = create_model_client_set_.begin();
       it != create_model_client_set_.end(); ++it) {
    OnCreateLanguageModelClientDisconnect(it.id(), 0, std::string());
  }
}
// Invoked when a pending create/clone client's mojo pipe disconnects (and
// explicitly from the destructor). Fires the matching JS AbortController —
// with |description| as the abort reason when one was provided — and drops
// the bookkeeping entry. |custom_reason| is unused.
void UtilityAIManager::OnCreateLanguageModelClientDisconnect(
    mojo::RemoteSetElementId id,
    uint32_t custom_reason,
    const std::string& description) {
  const auto entry = abort_controllers_.find(id);
  if (entry == abort_controllers_.end())
    return;
  v8::Isolate* const isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope handle_scope{isolate};
  v8::Local<v8::Object> controller = entry->second.Get(isolate);
  if (description.empty()) {
    gin_helper::CallMethod(isolate, controller, "abort");
  } else {
    gin_helper::CallMethod(isolate, controller, "abort", description);
  }
  abort_controllers_.erase(entry);
}
// Lazily resolves (and caches) the app-provided LanguageModel class by
// invoking the handler registered through the localAIHandler API, passing it
// `{ webContentsId, securityOrigin }`. Returns an empty global when no
// handler is registered or the handler misbehaves (returns a promise or a
// non-constructible / unrecognized value); misbehavior is additionally
// surfaced as an uncaught TypeError.
// NOTE(review): the result is cached for the lifetime of this manager, so a
// handler registered or replaced later is not picked up — confirm intended.
v8::Global<v8::Object>& UtilityAIManager::GetLanguageModelClass() {
  if (language_model_class_.IsEmpty()) {
    auto& handler = electron::api::local_ai_handler::GetPromptAPIHandler();
    if (handler.has_value()) {
      v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
      v8::HandleScope scope{isolate};
      auto details = gin_helper::Dictionary::CreateEmpty(isolate);
      if (web_contents_id_.has_value()) {
        details.Set("webContentsId", web_contents_id_.value());
      } else {
        // No associated WebContents: pass null rather than omitting the key.
        details.Set("webContentsId", nullptr);
      }
      details.Set("securityOrigin", security_origin_.Serialize());
      v8::Local<v8::Value> val = handler->Run(details);
      if (val->IsPromise()) {
        auto err = v8::Exception::TypeError(gin::StringToV8(
            isolate, "Cannot return a promise from the handler"));
        node::errors::TriggerUncaughtException(isolate, err, {});
        return language_model_class_;
      }
      if (!val->IsObject() ||
          !val->ToObject(isolate->GetCurrentContext())
               .ToLocalChecked()
               ->IsConstructor() ||
          !UtilityAILanguageModel::IsLanguageModelClass(isolate, val)) {
        auto err = v8::Exception::TypeError(
            gin::StringToV8(isolate, "Must provide a constructible class"));
        node::errors::TriggerUncaughtException(isolate, err, {});
        return language_model_class_;
      } else {
        language_model_class_.Reset(
            isolate,
            val->ToObject(isolate->GetCurrentContext()).ToLocalChecked());
      }
    }
  }
  return language_model_class_;
}
// Reports |error| to the pending create/clone client identified by
// |client_id| (if it is still connected) and discards the abort handle for
// the now-finished JS operation.
void UtilityAIManager::SendCreateLanguageModelError(
    mojo::RemoteSetElementId client_id,
    blink::mojom::AIManagerCreateClientError error) {
  abort_controllers_.erase(client_id);
  if (auto* client = create_model_client_set_.Get(client_id)) {
    client->OnError(error, /*quota_error_info=*/nullptr);
  }
}
// Validates the JS language-model object produced by create()/clone(), wraps
// it in a UtilityAILanguageModel bound to a fresh mojo pipe, and reports the
// session info (context window/usage, enabled input types) to |client_id|.
void UtilityAIManager::HandleLanguageModelResult(
    v8::Isolate* isolate,
    v8::Local<v8::Object> language_model,
    mojo::RemoteSetElementId client_id,
    blink::mojom::AILanguageModelCreateOptionsPtr options) {
  // The JS operation finished; its abort handle is no longer needed.
  abort_controllers_.erase(client_id);
  gin_helper::Dictionary dict;
  uint64_t context_usage = 0;
  uint64_t context_quota = 0;
  // The JS object must expose numeric `contextUsage` and `contextWindow`
  // properties; anything else is a session-creation failure.
  if (!ConvertFromV8(isolate, language_model, &dict) ||
      !dict.Get("contextUsage", &context_usage) ||
      !dict.Get("contextWindow", &context_quota)) {
    SendCreateLanguageModelError(
        client_id,
        blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
    return;
  }
  // TODO - How can the implementation specify the supported prompt types? For
  // now, assume all types are supported if the handler returns a valid object
  base::flat_set<blink::mojom::AILanguageModelPromptType> enabled_input_types;
  if (options->expected_inputs.has_value()) {
    for (const auto& expected_input : options->expected_inputs.value()) {
      enabled_input_types.insert(expected_input->type);
    }
  }
  blink::mojom::AIManagerCreateLanguageModelClient* client =
      create_model_client_set_.Get(client_id);
  if (!client) {
    return;
  }
  mojo::PendingRemote<blink::mojom::AILanguageModel> language_model_remote;
  // The receiver set owns the new UtilityAILanguageModel; it is destroyed
  // when its message pipe closes.
  language_model_receivers_.Add(
      std::make_unique<UtilityAILanguageModel>(language_model,
                                               weak_ptr_factory_.GetWeakPtr()),
      language_model_remote.InitWithNewPipeAndPassReceiver());
  client->OnResult(
      std::move(language_model_remote),
      blink::mojom::AILanguageModelInstanceInfo::New(
          context_quota, context_usage,
          blink::mojom::AILanguageModelSamplingParams::New(),
          std::vector<blink::mojom::AILanguageModelPromptType>(
              enabled_input_types.begin(), enabled_input_types.end()),
          /*audio_sample_rate_hz=*/std::nullopt,
          /*audio_channel_count=*/std::nullopt));
}
// Answers availability queries by calling the static `availability()` method
// on the registered LanguageModel class. Without a registered class the
// answer is kUnavailableUnknown; an unrecognized return value also maps to
// kUnavailableUnknown and raises an uncaught TypeError.
void UtilityAIManager::CanCreateLanguageModel(
    blink::mojom::AILanguageModelCreateOptionsPtr options,
    CanCreateLanguageModelCallback callback) {
  v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
  blink::mojom::ModelAvailabilityCheckResult availability =
      blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
  if (language_model_class.IsEmpty()) {
    std::move(callback).Run(availability);
  } else {
    // If a handler is set, we can create a language model.
    v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
    v8::HandleScope scope{isolate};
    v8::Local<v8::Value> val = gin_helper::CallMethod(
        isolate, language_model_class.Get(isolate), "availability", options);
    // Maps the JS result string onto the mojo enum (see the
    // ModelAvailabilityCheckResult converter above).
    auto RunCallback = [](v8::Isolate* isolate,
                          CanCreateLanguageModelCallback callback,
                          v8::Local<v8::Value> result) {
      blink::mojom::ModelAvailabilityCheckResult availability =
          blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
      if (result->IsString() &&
          gin::ConvertFromV8(isolate, result, &availability)) {
        std::move(callback).Run(availability);
      } else {
        auto err = v8::Exception::TypeError(gin::StringToV8(
            isolate, "Invalid return value from LanguageModel.availability()"));
        node::errors::TriggerUncaughtException(isolate, err, {});
        std::move(callback).Run(
            blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
      }
    };
    if (val->IsPromise()) {
      auto promise = val.As<v8::Promise>();
      // Single-shot mojo callback shared between then/catch branches.
      auto split_callback = base::SplitOnceCallback(std::move(callback));
      auto then_cb =
          base::BindOnce(RunCallback, isolate, std::move(split_callback.first));
      auto catch_cb = base::BindOnce(
          [](CanCreateLanguageModelCallback callback,
             v8::Local<v8::Value> result) {
            std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
                                        kUnavailableUnknown);
          },
          std::move(split_callback.second));
      std::ignore = promise->Then(
          isolate->GetCurrentContext(),
          gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
          gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
    } else {
      // The method is supposed to return a promise, but for
      // convenience allow developers to return a value directly
      RunCallback(isolate, std::move(callback), val);
    }
  }
}
// Mojo entry point for LanguageModel.create(): converts |options| to a JS
// dictionary and delegates to CreateLanguageModelInternal() with the
// registered class's "create" method.
void UtilityAIManager::CreateLanguageModel(
    mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
        client,
    blink::mojom::AILanguageModelCreateOptionsPtr options) {
  v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
  // Can't create language model if there's no language model class
  if (language_model_class.IsEmpty()) {
    mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
        client_remote(std::move(client));
    client_remote->OnError(
        blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
        /*quota_error_info=*/nullptr);
    return;
  }
  v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
  v8::HandleScope scope{isolate};
  gin_helper::Dictionary options_dict{
      isolate, gin::ConvertToV8(isolate, options).As<v8::Object>()};
  CreateLanguageModelInternal(isolate, std::move(client),
                              language_model_class.Get(isolate), "create",
                              std::move(options_dict), std::move(options));
}
// Shared implementation behind CreateLanguageModel() ("create" on the
// handler class) and UtilityAILanguageModel::Fork() ("clone" on a live model
// object). Registers |client| for the async reply, wires an AbortController
// into |options_dict| so a client disconnect can cancel the JS call, invokes
// the JS method, and validates that the (possibly promised) result is a
// LanguageModel instance before handing it to HandleLanguageModelResult().
void UtilityAIManager::CreateLanguageModelInternal(
    v8::Isolate* isolate,
    mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
        client,
    v8::Local<v8::Object> target,
    std::string_view method_name,
    gin_helper::Dictionary options_dict,
    blink::mojom::AILanguageModelCreateOptionsPtr options) {
  DCHECK(method_name == "create" || method_name == "clone");
  // Copy into a std::string: std::string_view::data() is not guaranteed to
  // be NUL-terminated, so it must not be handed to C-string consumers.
  const std::string method{method_name};
  std::string error_source = "LanguageModel." + method + "()";
  mojo::RemoteSetElementId client_id =
      create_model_client_set_.Add(std::move(client));
  // Store the abort controller so the disconnect handler can abort it.
  v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
  abort_controllers_.emplace(client_id,
                             v8::Global<v8::Object>(isolate, abort_controller));
  options_dict.Set("signal", abort_controller
                                 ->Get(isolate->GetCurrentContext(),
                                       gin::StringToV8(isolate, "signal"))
                                 .ToLocalChecked());
  v8::Local<v8::Value> val =
      gin_helper::CallMethod(isolate, target, method.c_str(), options_dict);
  if (val->IsPromise()) {
    auto promise = val.As<v8::Promise>();
    // Fulfillment: accept only objects recognized as LanguageModel
    // instances; anything else raises an uncaught TypeError and the client
    // receives kUnableToCreateSession.
    auto then_cb = base::BindOnce(
        [](base::WeakPtr<UtilityAIManager> weak_ptr, v8::Isolate* isolate,
           mojo::RemoteSetElementId client_id,
           blink::mojom::AILanguageModelCreateOptionsPtr options,
           std::string error_source, v8::Local<v8::Value> result) {
          if (weak_ptr) {
            if (result->IsObject() &&
                UtilityAILanguageModel::IsLanguageModel(isolate, result)) {
              weak_ptr->HandleLanguageModelResult(
                  isolate, result.As<v8::Object>(), client_id,
                  std::move(options));
            } else {
              auto err = v8::Exception::TypeError(gin::StringToV8(
                  isolate, "Invalid return value from " + error_source));
              node::errors::TriggerUncaughtException(isolate, err, {});
              weak_ptr->SendCreateLanguageModelError(
                  client_id, blink::mojom::AIManagerCreateClientError::
                                 kUnableToCreateSession);
            }
          }
        },
        weak_ptr_factory_.GetWeakPtr(), isolate, client_id, std::move(options),
        error_source);
    // Rejection: report a generic session-creation failure.
    auto catch_cb = base::BindOnce(
        [](base::WeakPtr<UtilityAIManager> weak_ptr,
           mojo::RemoteSetElementId client_id, v8::Local<v8::Value> result) {
          if (weak_ptr) {
            weak_ptr->SendCreateLanguageModelError(
                client_id, blink::mojom::AIManagerCreateClientError::
                               kUnableToCreateSession);
          }
        },
        weak_ptr_factory_.GetWeakPtr(), client_id);
    std::ignore = promise->Then(
        isolate->GetCurrentContext(),
        gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
        gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
  } else if (val->IsObject() &&
             UtilityAILanguageModel::IsLanguageModel(isolate, val)) {
    // The method is supposed to return a promise, but for
    // convenience allow developers to return a value directly
    HandleLanguageModelResult(isolate, val.As<v8::Object>(), client_id,
                              std::move(options));
  } else {
    auto err = v8::Exception::TypeError(gin::StringToV8(
        isolate, "Invalid return value from " + error_source));
    node::errors::TriggerUncaughtException(isolate, err, {});
    SendCreateLanguageModelError(
        client_id,
        blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
  }
}
// Summarizer support is not implemented by this utility-process AIManager;
// availability queries always report unavailable.
void UtilityAIManager::CanCreateSummarizer(
    blink::mojom::AISummarizerCreateOptionsPtr options,
    CanCreateSummarizerCallback callback) {
  std::move(callback).Run(
      blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}

void UtilityAIManager::CreateSummarizer(
    mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
    blink::mojom::AISummarizerCreateOptionsPtr options) {
  NOTIMPLEMENTED();
}
// Language-model default/max sampling params are not implemented here.
// Answer the mojo request with null (the mojom result is nullable) instead
// of dropping |callback|: destroying an unrun mojo response callback closes
// the message pipe and kills the caller's AIManager connection.
void UtilityAIManager::GetLanguageModelParams(
    GetLanguageModelParamsCallback callback) {
  NOTIMPLEMENTED();
  std::move(callback).Run(nullptr);
}
// Writer, rewriter, proofreader, and download-progress support are not
// implemented by this utility-process AIManager: availability queries report
// unavailable, and the create/observe entry points are stubs.
void UtilityAIManager::CanCreateWriter(
    blink::mojom::AIWriterCreateOptionsPtr options,
    CanCreateWriterCallback callback) {
  std::move(callback).Run(
      blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}

void UtilityAIManager::CreateWriter(
    mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
    blink::mojom::AIWriterCreateOptionsPtr options) {
  NOTIMPLEMENTED();
}

void UtilityAIManager::CanCreateRewriter(
    blink::mojom::AIRewriterCreateOptionsPtr options,
    CanCreateRewriterCallback callback) {
  std::move(callback).Run(
      blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}

void UtilityAIManager::CreateRewriter(
    mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
    blink::mojom::AIRewriterCreateOptionsPtr options) {
  NOTIMPLEMENTED();
}

void UtilityAIManager::CanCreateProofreader(
    blink::mojom::AIProofreaderCreateOptionsPtr options,
    CanCreateProofreaderCallback callback) {
  std::move(callback).Run(
      blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}

void UtilityAIManager::CreateProofreader(
    mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
    blink::mojom::AIProofreaderCreateOptionsPtr options) {
  NOTIMPLEMENTED();
}

void UtilityAIManager::AddModelDownloadProgressObserver(
    mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
        observer_remote) {
  NOTIMPLEMENTED();
}
} // namespace electron

View File

@@ -0,0 +1,125 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
#include <optional>
#include <string>
#include <string_view>
#include "base/memory/weak_ptr.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "services/on_device_model/public/mojom/download_observer.mojom-forward.h"
#include "shell/common/gin_helper/dictionary.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
#include "url/origin.h"
#include "v8/include/v8.h"
namespace electron {
class UtilityAILanguageModel;
// The utility-side implementation of `blink::mojom::AIManager`.
// The utility-side implementation of `blink::mojom::AIManager`.
// Delegates language-model creation to an app-registered JS class (resolved
// through the localAIHandler API); all other model kinds (summarizer,
// writer, rewriter, proofreader) are reported unavailable.
class UtilityAIManager : public blink::mojom::AIManager {
 public:
  UtilityAIManager(std::optional<int32_t> web_contents_id,
                   const url::Origin& security_origin);
  UtilityAIManager(const UtilityAIManager&) = delete;
  UtilityAIManager& operator=(const UtilityAIManager&) = delete;
  ~UtilityAIManager() override;

 private:
  // UtilityAILanguageModel::Fork() calls CreateLanguageModelInternal().
  friend class UtilityAILanguageModel;
  // Aborts the JS create/clone operation tied to a disconnected client.
  void OnCreateLanguageModelClientDisconnect(mojo::RemoteSetElementId id,
                                             uint32_t custom_reason,
                                             const std::string& description);
  // Lazily resolves and caches the app-provided LanguageModel class; may
  // return an empty global when no valid handler is registered.
  [[nodiscard]] v8::Global<v8::Object>& GetLanguageModelClass();
  // Reports |error| to a pending create/clone client and drops its abort
  // handle.
  void SendCreateLanguageModelError(
      mojo::RemoteSetElementId client_id,
      blink::mojom::AIManagerCreateClientError error);
  // Shared "create"/"clone" driver used by CreateLanguageModel() and Fork().
  void CreateLanguageModelInternal(
      v8::Isolate* isolate,
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      v8::Local<v8::Object> target,
      std::string_view method_name,
      gin_helper::Dictionary options_dict,
      blink::mojom::AILanguageModelCreateOptionsPtr options);
  // Validates a JS language-model object and reports the new session to the
  // pending client.
  void HandleLanguageModelResult(
      v8::Isolate* isolate,
      v8::Local<v8::Object> language_model,
      mojo::RemoteSetElementId client_id,
      blink::mojom::AILanguageModelCreateOptionsPtr options);
  // `blink::mojom::AIManager` implementation.
  void CanCreateLanguageModel(
      blink::mojom::AILanguageModelCreateOptionsPtr options,
      CanCreateLanguageModelCallback callback) override;
  void CreateLanguageModel(
      mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
          client,
      blink::mojom::AILanguageModelCreateOptionsPtr options) override;
  void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
                           CanCreateSummarizerCallback callback) override;
  void CreateSummarizer(
      mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
      blink::mojom::AISummarizerCreateOptionsPtr options) override;
  void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
  void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
                       CanCreateWriterCallback callback) override;
  void CreateWriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
      blink::mojom::AIWriterCreateOptionsPtr options) override;
  void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
                         CanCreateRewriterCallback callback) override;
  void CreateRewriter(
      mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
      blink::mojom::AIRewriterCreateOptionsPtr options) override;
  void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
                            CanCreateProofreaderCallback callback) override;
  void CreateProofreader(
      mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
          client,
      blink::mojom::AIProofreaderCreateOptionsPtr options) override;
  void AddModelDownloadProgressObserver(
      mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
          observer_remote) override;
  // WebContents id of the requesting frame, if any.
  std::optional<int32_t> web_contents_id_;
  // Origin of the requesting context; forwarded to the JS handler.
  url::Origin security_origin_;
  // Cached app-provided LanguageModel class (see GetLanguageModelClass()).
  v8::Global<v8::Object> language_model_class_;
  mojo::RemoteSet<blink::mojom::AIManagerCreateLanguageModelClient>
      create_model_client_set_;
  // Maps each in-progress CreateLanguageModel client to its AbortController
  // so we can abort the JS-side operation if the client disconnects.
  absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
      abort_controllers_;
  // Owns all created UtilityAILanguageModel instances
  mojo::UniqueReceiverSet<blink::mojom::AILanguageModel>
      language_model_receivers_;
  base::WeakPtrFactory<UtilityAIManager> weak_ptr_factory_{this};
};
} // namespace electron
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/api/electron_api_local_ai_handler.h"
#include <optional>
#include "base/no_destructor.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/node_includes.h"
#include "v8/include/v8.h"
namespace electron::api::local_ai_handler {
// Installs (or, when `val` is null, clears) the process-wide Prompt API
// handler. Any value that is neither null nor convertible to a
// PromptAPIHandler callback raises a JS TypeError.
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> val) {
  if (val->IsNull()) {
    // Null explicitly unregisters any previously installed handler.
    GetPromptAPIHandler() = std::nullopt;
    return;
  }
  PromptAPIHandler handler;
  if (!gin::ConvertFromV8(isolate, val, &handler)) {
    isolate->ThrowException(v8::Exception::TypeError(
        gin::StringToV8(isolate, "Must pass null or function")));
    return;
  }
  GetPromptAPIHandler() = handler;
}
// Returns the process-wide Prompt API handler slot. `std::nullopt` means
// no handler is currently registered. Uses base::NoDestructor so the
// stored callback is never destroyed during shutdown.
std::optional<PromptAPIHandler>& GetPromptAPIHandler() {
  static base::NoDestructor<std::optional<PromptAPIHandler>> prompt_api_handler;
  return *prompt_api_handler;
}
} // namespace electron::api::local_ai_handler
namespace {
// Entry point for the `electron_utility_local_ai_handler` linked binding:
// exposes setPromptAPIHandler() on the module's exports object so utility
// process JS can register/unregister the Prompt API handler.
void Initialize(v8::Local<v8::Object> exports,
                v8::Local<v8::Value> unused,
                v8::Local<v8::Context> context,
                void* priv) {
  v8::Isolate* const isolate = v8::Isolate::GetCurrent();
  gin_helper::Dictionary dict{isolate, exports};
  dict.SetMethod("setPromptAPIHandler",
                 &electron::api::local_ai_handler::SetPromptAPIHandler);
}
} // namespace
NODE_LINKED_BINDING_CONTEXT_AWARE(electron_utility_local_ai_handler, Initialize)

View File

@@ -0,0 +1,28 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
// Include guard fixed to match the file's include path
// (shell/utility/api/electron_api_local_ai_handler.h, as included by the
// corresponding .cc): the previous macro omitted the "API" segment.
#ifndef ELECTRON_SHELL_UTILITY_API_ELECTRON_API_LOCAL_AI_HANDLER_H_
#define ELECTRON_SHELL_UTILITY_API_ELECTRON_API_LOCAL_AI_HANDLER_H_

#include <optional>

#include "base/functional/callback_forward.h"
#include "v8/include/v8-forward.h"

namespace gin_helper {
class Dictionary;
}

namespace electron::api::local_ai_handler {

// Callback registered by utility-process JS to service Prompt API
// requests; it receives an options dictionary and returns a V8 value.
using PromptAPIHandler =
    base::RepeatingCallback<v8::Local<v8::Value>(gin_helper::Dictionary)>;

// Installs (or clears, when `value` is null) the process-wide handler.
// Throws a JS TypeError for any other non-convertible value.
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> value);

// Returns the mutable process-wide handler slot; `std::nullopt` when no
// handler is registered.
[[nodiscard]] std::optional<PromptAPIHandler>& GetPromptAPIHandler();

}  // namespace electron::api::local_ai_handler

#endif  // ELECTRON_SHELL_UTILITY_API_ELECTRON_API_LOCAL_AI_HANDLER_H_

View File

@@ -0,0 +1,995 @@
import { BrowserWindow, session, utilityProcess } from 'electron/main';
import { expect } from 'chai';
import { on, once } from 'node:events';
import * as path from 'node:path';
import { ifdescribe } from './lib/spec-helpers';
import { closeAllWindows } from './lib/window-helpers';
const features = process._linkedBinding('electron_common_features');
// Resolve the absolute path of a fixture file inside the
// local-ai-handler fixture directory.
function getFixturePath (fixtureName: string) {
  const fixtureDir = path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler');
  return path.join(fixtureDir, fixtureName);
}
// Subscribe to the handler's 'message' events, run fn, then resolve with
// the first message whose type matches. The subscription is opened before
// fn runs so a message triggered as a side effect of fn cannot be missed;
// the result of fn itself is ignored. Resolves null only if the message
// stream ends without a matching message.
async function listenForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<void> | void) {
  const incoming = on(aiHandler, 'message');
  await fn();
  for await (const args of incoming) {
    const [received] = args;
    if (received.type === messageType) {
      return received;
    }
  }
  return null;
}
// Invoke fn and wait until a message of the given type has been observed,
// handing back the promise fn produced so the caller can also await fn's
// own result (unlike listenForMessage, which discards it).
async function waitForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<unknown>) {
  let pending: Promise<unknown> | undefined;
  await listenForMessage(aiHandler, messageType, () => {
    pending = fn();
  });
  return { promise: pending as Promise<unknown> };
}
ifdescribe(features.isPromptAPIEnabled())('localAIHandler module', () => {
const fixtures = path.resolve(__dirname, 'fixtures');
let w: Electron.BrowserWindow;
// Fork the given fixture as a utility process, wait until it has spawned,
// then register it as the session's local AI handler.
async function forkAndRegisterHandler (fixtureName: string) {
  const child = utilityProcess.fork(getFixturePath(fixtureName));
  await once(child, 'spawn');
  w.webContents.session.registerLocalAIHandler(child);
  return child;
}
// Post a message to the utility process and wait for its acknowledgement
// reply before returning, so the command is known to have been applied.
// The reply listener is installed before posting to avoid a race.
async function sendControllableMessage (aiHandler: Electron.UtilityProcess, message: unknown) {
  const acknowledgement = once(aiHandler, 'message');
  aiHandler.postMessage(message);
  await acknowledgement;
}
beforeEach(async () => {
  // Each test gets a fresh hidden window with the Prompt API blink
  // features force-enabled so LanguageModel is exposed to page JS.
  w = new BrowserWindow({
    show: false,
    webPreferences: {
      enableBlinkFeatures: 'AIPromptAPI,AIPromptAPIMultimodalInput'
    }
  });
  await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
});
afterEach(() => {
  // Unregister any handler installed by the test before tearing down
  // windows so later tests start from a clean session.
  w.webContents.session.registerLocalAIHandler(null);
  closeAllWindows();
});
// LanguageModel.availability() must mirror whatever the registered
// utility-process handler reports, and degrade to 'unavailable' whenever
// the handler is absent, invalid, rejecting, or its process has died.
describe('LanguageModel.availability()', () => {
  it('is unavailable if invalid value returned', async () => {
    await forkAndRegisterHandler('buggy-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('returns "available" when handler reports available', async () => {
    await forkAndRegisterHandler('default-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
  });
  it('returns "downloadable" when handler reports downloadable', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloadable' });
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloadable');
  });
  it('returns "downloading" when handler reports downloading', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloading');
  });
  it('returns "unavailable" when handler reports unavailable', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'unavailable' });
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('returns "unavailable" when the availability() promise rejects', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'reject' });
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('returns "unavailable" if the utility process dies', async () => {
    const aiHandler = await forkAndRegisterHandler('default-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
    // Kill the handler process and confirm availability degrades.
    aiHandler.kill();
    await once(aiHandler, 'exit');
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('returns "unavailable" if not registered', async () => {
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('returns "unavailable" if registered but utility process has not set handler', async () => {
    await forkAndRegisterHandler('no-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('passes options to the availability() call', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
    const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }] };
    // Subscribe before triggering the call so the message is not missed.
    const message = once(aiHandler, 'message');
    await w.webContents.executeJavaScript(`LanguageModel.availability(${JSON.stringify(options)})`);
    const [receivedMessage] = await message;
    expect(receivedMessage.options).to.deep.equal(options);
    expect(receivedMessage.type).to.equal('availability-called');
  });
});
// LanguageModel.create() must reject cleanly on handler failure, process
// death, or unregistration mid-create; pass options (including the abort
// signal) through to the handler; and expose context metadata on success.
describe('LanguageModel.create()', () => {
  // Run LanguageModel.create() in the renderer and assert it rejects.
  async function expectRejectedWithError (message: string | RegExp, options?: Object) {
    // Unwrap the error message because NotAllowedError won't serialize
    if (options) {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)}).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
    } else {
      await expect(w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(message);
    }
  }
  it('rejects if invalid value returned', async () => {
    await forkAndRegisterHandler('buggy-language-model.js');
    await expectRejectedWithError(/unable to create/);
  });
  it('rejects when no handler is registered', async () => {
    await expectRejectedWithError(/unable to create/);
  });
  it('rejects when handler promise rejects', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-create', value: 'reject' });
    await expectRejectedWithError(/unable to create/);
  });
  it('rejects if the utility process dies during creation', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
    });
    aiHandler.kill();
    await once(aiHandler, 'exit');
    await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
  });
  it('rejects if the handler gets unregistered during creation', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
    });
    w.webContents.session.registerLocalAIHandler(null);
    await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
  });
  it('creates a LanguageModel instance from a valid handler', async () => {
    await forkAndRegisterHandler('default-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
  });
  it('passes initialPrompts to create()', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const options = { initialPrompts: [{ role: 'system', content: [{ type: 'text', value: 'You are Electron AI' }] }] };
    const message = await listenForMessage(aiHandler, 'create-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
    });
    // The signal is compared separately since it is not serializable.
    expect(message.options).to.have.property('signal');
    delete message.options.signal;
    expect(message.options).to.deep.equal({ initialPrompts: options.initialPrompts.map(prompt => ({ ...prompt, prefix: false })) });
  });
  it('passes expectedInputs and expectedOutputs options', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'text', languages: ['en', 'fr'] }] };
    const message = await listenForMessage(aiHandler, 'create-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
    });
    expect(message.options).to.have.property('signal');
    delete message.options.signal;
    expect(message.options).to.deep.equal(options);
  });
  it('plumbs the abort signal through', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
    const message = await listenForMessage(aiHandler, 'create-aborted', async () => {
      await expect(w.webContents.executeJavaScript('LanguageModel.create({ signal: AbortSignal.timeout(500) }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
    });
    expect(message).not.null();
  });
  it('exposes contextUsage and contextWindow on the created model', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 12345 });
  });
});
// LanguageModel.prompt() coverage: error propagation (invalid values,
// rejections, destroyed model, dead/unregistered handler), both string
// and streaming handler responses, input/option plumbing, abort signals,
// and contextUsage accounting.
describe('LanguageModel.prompt()', () => {
  // Run create()+prompt() in the renderer and assert the prompt rejects.
  async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
    // Unwrap the error message because NotAllowedError won't serialize
    if (options) {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)}, ${JSON.stringify(options)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
    } else {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
    }
  }
  it('rejects when handler returns an invalid value', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
    await expectRejectedWithError(/error occurred/, 'Test prompt');
  });
  it('rejects when handler promise rejects', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
    await expectRejectedWithError(/error occurred/, 'Test prompt');
  });
  it('rejects after the model has been destroyed', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the utility process dies during prompt', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
    });
    aiHandler.kill();
    await once(aiHandler, 'exit');
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the handler gets unregistered during prompt', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
    });
    w.webContents.session.registerLocalAIHandler(null);
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('returns a string response from the handler', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
  });
  it('returns a ReadableStream response from the handler', async () => {
    await forkAndRegisterHandler('streaming-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('Hello World');
  });
  it('passes string input to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt(\'hello world\'))');
    });
    expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
  });
  it('passes LanguageModelMessage[] input to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(input)}))`);
    });
    expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
  });
  it('passes responseConstraint option to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt('test', { responseConstraint: ${JSON.stringify(responseConstraint)} }))`);
    });
    expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
  });
  it('plumbs the abort signal through', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
      await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
    });
    expect(message).not.null();
  });
  it('updates contextUsage after a prompt', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.prompt("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 10 });
  });
});
// LanguageModel.promptStreaming() coverage: mirrors the prompt() suite but
// consumes the returned ReadableStream via a renderer-side collector.
describe('LanguageModel.promptStreaming()', () => {
  // Renderer-side helper (as source text) that drains a ReadableStream
  // into a single concatenated string.
  const collectStream = 'async (stream) => { const reader = stream.getReader(); let r = ""; while (true) { const { done, value } = await reader.read(); if (done) return r; r += value; } }';
  async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
    const collectStreamFn = collectStream;
    // Unwrap the error message because NotAllowedError won't serialize
    if (options) {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)}, ${JSON.stringify(options)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
    } else {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
    }
  }
  it('rejects when handler returns an invalid value', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
    await expectRejectedWithError(/error occurred/, 'Test prompt');
  });
  it('rejects when ReadableStream returns an invalid value', async () => {
    await forkAndRegisterHandler('buggy-streaming-language-model.js');
    await expectRejectedWithError(/has been destroyed/, 'Test prompt');
  });
  it('rejects when handler promise rejects', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
    await expectRejectedWithError(/error occurred/, 'Test prompt');
  });
  it('rejects after the model has been destroyed', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { model.destroy(); const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the utility process dies during prompt', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
      return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
    });
    aiHandler.kill();
    await once(aiHandler, 'exit');
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the handler gets unregistered during prompt', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
      return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
    });
    w.webContents.session.registerLocalAIHandler(null);
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('returns a string response from the handler', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('foobar');
  });
  it('returns a ReadableStream response from the handler', async () => {
    await forkAndRegisterHandler('streaming-language-model.js');
    expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('Hello World');
  });
  it('passes string input to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('hello world')); })`);
    });
    expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
  });
  it('passes LanguageModelMessage[] input to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming(${JSON.stringify(input)})); })`);
    });
    expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
  });
  it('passes responseConstraint option to the handler', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
    const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
      await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('test', { responseConstraint: ${JSON.stringify(responseConstraint)} })); })`);
    });
    expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
  });
  it('plumbs the abort signal through', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
      await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("test", { signal: AbortSignal.timeout(500) })); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/signal timed out/);
    });
    expect(message).not.null();
  });
  it('updates contextUsage after a prompt', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
    expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; await collect(model.promptStreaming("hello world")); return { contextUsage: model.contextUsage }; })`)).to.deep.equal({ contextUsage: 10 });
  });
});
// LanguageModel.append() coverage: error propagation, no-response append
// semantics, abort-signal plumbing, and contextUsage accounting.
describe('LanguageModel.append()', () => {
  it('rejects when handler promise rejects', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'reject' });
    await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/error occurred/);
  });
  it('rejects after the model has been destroyed', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.append("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the utility process dies during append', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
    });
    aiHandler.kill();
    await once(aiHandler, 'exit');
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('rejects if the handler gets unregistered during append', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
    const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
      return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
    });
    w.webContents.session.registerLocalAIHandler(null);
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
  });
  it('appends a message without producing a response', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.be.undefined();
  });
  it('plumbs the abort signal through', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
    const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
      await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
    });
    expect(message).not.null();
  });
  it('updates contextUsage after append', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.append("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 5 });
  });
});
// measureContextUsage() asks the handler to compute a token count for the
// given input without mutating model state.
describe('LanguageModel.measureContextUsage()', () => {
// A non-numeric handler return must surface as a rejection in the renderer.
it('rejects if invalid value returned', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'invalid' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
});
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'reject' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
});
it('rejects after the model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.measureContextUsage("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
// Simulates the handler process dying while a measure call is in flight.
it('rejects if the utility process dies during call', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejected();
});
// Unregistering the handler mid-call must also reject the pending promise.
it('rejects if the handler gets unregistered during call', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejected();
});
// The basic fixture always returns 42 from measureContextUsage().
it('returns the token count for the given input', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("hello world"))')).to.equal(42);
});
// TODO(dsanders11): Upstream Chromium issue prevents this test from passing as
// there's no Mojo connection to disconnect trip abort signal
it.skip('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'measure-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
});
// clone() must return a new LanguageModel carrying over the original's context.
describe('LanguageModel.clone()', () => {
// A non-LanguageModel handler return must surface as a rejection.
it('rejects when clone() returns a non-LanguageModel value', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'invalid' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects when clone() promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'reject' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects after the original model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.clone(); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
// Handler process death during an in-flight clone must reject the promise.
it('rejects if the utility process dies during clone', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects if the handler gets unregistered during clone', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('returns a new LanguageModel instance', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
});
// The basic fixture adds 10 to contextUsage per prompt and was created with
// contextWindow 12345, so the clone must reflect both values.
it('preserves context from the original model', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript(`
LanguageModel.create().then(async (model) => {
await model.prompt("hello");
const cloned = await model.clone();
return { contextUsage: cloned.contextUsage, contextWindow: cloned.contextWindow };
})
`)).to.deep.equal({ contextUsage: 10, contextWindow: 12345 });
});
// AbortSignal must be forwarded to the handler's clone() options.
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'clone-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone({ signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
});
// destroy() tears down the model: the handler is notified, subsequent calls
// reject, and in-flight calls are aborted.
describe('LanguageModel.destroy()', () => {
  it('destroys the model', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    // The fixture posts 'destroy-called' from its destroy() override, and any
    // call after destroy must reject.
    const message = await listenForMessage(aiHandler, 'destroy-called', async () => {
      await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test"); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
    });
    expect(message).not.null();
  });
  it('aborts any in-progress prompt calls', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
    await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
    // Start a prompt that blocks in the handler until its signal aborts.
    const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
      return w.webContents.executeJavaScript('window._model.prompt("Test").catch(err => { throw err.message; })');
    });
    // Destroying the model should trip the handler-side abort signal.
    const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
      await w.webContents.executeJavaScript('window._model.destroy()');
    });
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
    expect(message).not.null();
  });
  it('aborts any in-progress append calls', async () => {
    const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
    await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
    await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
    // Start an append that blocks in the handler until its signal aborts.
    const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
      return w.webContents.executeJavaScript('window._model.append("Test").catch(err => { throw err.message; })');
    });
    const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
      await w.webContents.executeJavaScript('window._model.destroy()');
    });
    // NOTE(review): a second, redundant `window._model.destroy()` call that
    // followed the listener above was removed — the destroy inside the
    // listenForMessage callback already triggers the abort (mirroring the
    // prompt test), and destroy() idempotency is covered by the test below.
    await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
    expect(message).not.null();
  });
  it('can be called multiple times without error', async () => {
    await forkAndRegisterHandler('basic-language-model.js');
    expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); model.destroy(); model.destroy(); return true; })')).to.equal(true);
  });
});
// Validates the utility-process side registration API: invalid handler return
// values make the API report 'unavailable', and valid handlers receive a
// details object identifying the requesting frame.
describe('setPromptAPIHandler()', () => {
// The handler must return the class synchronously, not a promise.
it('rejects if handler returns a promise', async () => {
await forkAndRegisterHandler('promise-handler-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('rejects if handler returns a non-class value', async () => {
await forkAndRegisterHandler('non-class-handler-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('rejects if handler returns a class not extending LanguageModel', async () => {
await forkAndRegisterHandler('non-language-model-handler.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('receives webContentsId in the details object', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.details).to.have.property('webContentsId', w.webContents.id);
});
it('receives securityOrigin in the details object', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.details).to.have.property('securityOrigin');
expect(message.details.securityOrigin).to.be.a('string').and.not.be.empty();
});
// The handler result is cached per (webContentsId, securityOrigin) pair; the
// handler-details fixture reports its cumulative call count.
it('is called once per webContentsId and securityOrigin pair', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.callCount).to.equal(1);
// Calling availability again should not trigger the handler again
await w.webContents.executeJavaScript('LanguageModel.availability()');
// Create a second window with the same session - should trigger handler again (different webContentsId)
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: w.webContents.session,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
const message2 = await listenForMessage(aiHandler, 'handler-called', async () => {
await w2.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message2.callCount).to.equal(2);
});
it('can be cleared by calling with null', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Clear the handler inside the utility process
await sendControllableMessage(aiHandler, { command: 'clear-handler' });
// Existing Prompt API bindings should still work until the page is reloaded
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Load a new page to get a fresh Prompt API binding
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
// Should be unavailable since setPromptAPIHandler(null) was called in the utility process
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
});
// The base LanguageModel class (registered as-is by the default fixture) must
// provide working no-op defaults for the entire API surface.
describe('LanguageModel base class', () => {
it('provides default no-op implementations for all methods', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Hi"))')).to.be.undefined();
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Hi"))')).to.equal(0);
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
});
it('can use the base LanguageModel class directly without subclassing', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 0 });
});
});
// Handlers are registered per-session: all windows sharing a session share a
// handler, and distinct sessions are fully independent.
describe('session isolation', () => {
it('applies to all windows using the same session', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// w2 uses the default session implicitly, same as w.
const w2 = new BrowserWindow({
show: false,
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
});
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
});
it('different sessions can use different handler processes', async () => {
const ses1 = session.fromPartition('ai-isolation-1');
const ses2 = session.fromPartition('ai-isolation-2');
const w1 = new BrowserWindow({
show: false,
webPreferences: {
session: ses1,
enableBlinkFeatures: 'AIPromptAPI'
}
});
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: ses2,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await Promise.all([
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
]);
// Register a different fixture process on each session.
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler1, 'spawn');
ses1.registerLocalAIHandler(aiHandler1);
const aiHandler2 = utilityProcess.fork(getFixturePath('default-language-model.js'));
await once(aiHandler2, 'spawn');
ses2.registerLocalAIHandler(aiHandler2);
try {
// basic-language-model returns 'foobar'
expect(await w1.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
// default-language-model returns ''
expect(await w2.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
} finally {
// Always unregister so later tests start from a clean slate.
ses1.registerLocalAIHandler(null);
ses2.registerLocalAIHandler(null);
}
});
it('clearing one session handler does not affect another', async () => {
const ses1 = session.fromPartition('ai-isolation-clear-1');
const ses2 = session.fromPartition('ai-isolation-clear-2');
const w1 = new BrowserWindow({
show: false,
webPreferences: {
session: ses1,
enableBlinkFeatures: 'AIPromptAPI'
}
});
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: ses2,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await Promise.all([
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
]);
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler1, 'spawn');
ses1.registerLocalAIHandler(aiHandler1);
const aiHandler2 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler2, 'spawn');
ses2.registerLocalAIHandler(aiHandler2);
try {
// Both should be available
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Clear handler for session 1
ses1.registerLocalAIHandler(null);
// Session 1 should be unavailable, session 2 should still be available
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
} finally {
ses1.registerLocalAIHandler(null);
ses2.registerLocalAIHandler(null);
}
});
});
});

View File

@@ -1,4 +1,4 @@
import { app, session, BrowserWindow, net, ipcMain, Session, webFrameMain, WebFrameMain } from 'electron/main';
import { app, session, BrowserWindow, net, ipcMain, Session, utilityProcess, webFrameMain, WebFrameMain } from 'electron/main';
import * as auth from 'basic-auth';
import { expect } from 'chai';
@@ -2132,4 +2132,96 @@ describe('session module', () => {
expect((await cookies.get({ url: 'https://example.org/', name: 'testdotorg' })).length).to.equal(0);
});
});
// Session-level registration API: a UtilityProcess becomes the local AI
// handler for all Prompt API requests in this session.
describe('ses.registerLocalAIHandler()', () => {
  let w: Electron.BrowserWindow;
  // Every test in this suite forks the same fixture; build its path once
  // instead of repeating the join/resolve chain in each test. (`__dirname` is
  // already absolute, so the inner `path.resolve` the tests used was redundant.)
  const defaultHandlerPath = path.join(__dirname, 'fixtures', 'api', 'local-ai-handler', 'default-language-model.js');
  beforeEach(() => {
    w = new BrowserWindow({
      show: false,
      webPreferences: {
        enableBlinkFeatures: 'AIPromptAPI'
      }
    });
  });
  afterEach(() => {
    // Clear any handler left registered so later suites start clean.
    w.webContents.session.registerLocalAIHandler(null);
    closeAllWindows();
  });
  it('registers a utility process as the AI handler', async () => {
    await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
    const aiHandler = utilityProcess.fork(defaultHandlerPath);
    w.webContents.session.registerLocalAIHandler(aiHandler);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
  });
  it('clears the handler when called with null', async () => {
    await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
    const { session } = w.webContents;
    const aiHandler = utilityProcess.fork(defaultHandlerPath);
    session.registerLocalAIHandler(aiHandler);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
    session.registerLocalAIHandler(null);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
  });
  it('prevents new LanguageModel.create() calls after clearing', async () => {
    await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
    const { session } = w.webContents;
    const aiHandler = utilityProcess.fork(defaultHandlerPath);
    session.registerLocalAIHandler(aiHandler);
    await expect(w.webContents.executeJavaScript('LanguageModel.create()')).to.eventually.be.fulfilled();
    session.registerLocalAIHandler(null);
    await expect(w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/unable to create/);
  });
  it('can re-register a new handler after clearing', async () => {
    await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
    const { session } = w.webContents;
    const aiHandler1 = utilityProcess.fork(defaultHandlerPath);
    session.registerLocalAIHandler(aiHandler1);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
    session.registerLocalAIHandler(null);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
    const aiHandler2 = utilityProcess.fork(defaultHandlerPath);
    session.registerLocalAIHandler(aiHandler2);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
    await expect(w.webContents.executeJavaScript('LanguageModel.create()')).to.eventually.be.fulfilled();
  });
  it('throws when called with a non-UtilityProcess argument', () => {
    const { session } = w.webContents;
    expect(() => session.registerLocalAIHandler('not a process' as any)).to.throw();
    expect(() => session.registerLocalAIHandler(42 as any)).to.throw();
    expect(() => session.registerLocalAIHandler({} as any)).to.throw();
  });
  it('can register an existing handler again', async () => {
    await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
    const { session } = w.webContents;
    const aiHandler = utilityProcess.fork(defaultHandlerPath);
    session.registerLocalAIHandler(aiHandler);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
    session.registerLocalAIHandler(null);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
    session.registerLocalAIHandler(aiHandler);
    expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
  });
});
});

View File

@@ -0,0 +1,28 @@
const { localAIHandler, LanguageModel } = require('electron/utility');

// Fixture: a minimal stateful model with fixed responses. prompt() always
// replies 'foobar' (+10 context), append() adds 5 context, and
// measureContextUsage() always reports 42 tokens.
localAIHandler.setPromptAPIHandler(() => {
  class BasicLanguageModel extends LanguageModel {
    static async create () {
      return new BasicLanguageModel({ contextUsage: 0, contextWindow: 12345 });
    }

    async prompt () {
      this.contextUsage += 10;
      return 'foobar';
    }

    async append () {
      this.contextUsage += 5;
    }

    async measureContextUsage () {
      return 42;
    }
  }

  return BasicLanguageModel;
});

View File

@@ -0,0 +1,15 @@
const { localAIHandler, LanguageModel } = require('electron/utility');

// Fixture: deliberately misbehaving handler whose create() and availability()
// both return a plain string instead of the expected types.
localAIHandler.setPromptAPIHandler(() => {
  class BuggyLanguageModel extends LanguageModel {
    static async create () {
      return 'foobar';
    }

    static availability () {
      return 'foobar';
    }
  }

  return BuggyLanguageModel;
});

View File

@@ -0,0 +1,28 @@
const { localAIHandler, LanguageModel } = require('electron/utility');
const { ReadableStream } = require('node:stream/web');

// Fixture: streams a non-string chunk (the number 99) mid-response to
// exercise renderer-side validation of streamed prompt output.
localAIHandler.setPromptAPIHandler(() => {
  class BuggyStreamingLanguageModel extends LanguageModel {
    static async create () {
      return new BuggyStreamingLanguageModel({ contextUsage: 0, contextWindow: 0 });
    }

    async prompt () {
      this.contextUsage += 10;
      return new ReadableStream({
        async start (controller) {
          controller.enqueue('Hello ');
          controller.enqueue(99);
          controller.close();
        }
      });
    }
  }

  return BuggyStreamingLanguageModel;
});

View File

@@ -0,0 +1,110 @@
const { localAIHandler, LanguageModel } = require('electron/utility');
const { once } = require('node:events');
// Mutable knobs the parent (test) process flips via control messages below.
// Special values: 'reject' makes the call reject, 'wait-for-abort' makes it
// block until its AbortSignal fires, 'invalid' (clone only) returns a bad value.
let availabilityState = 'available';
let createResponse = null;
let promptResponse = 'Hello World';
let appendResponse = null;
let measureResponse = 100;
let cloneResponse = null;
// Control channel: each message sets one knob; every message is ack'd so the
// spec-side sendControllableMessage() helper can await completion.
process.parentPort.on('message', (e) => {
const { command, value } = e.data;
if (command === 'set-availability') {
availabilityState = value;
} else if (command === 'set-create') {
createResponse = value;
} else if (command === 'set-prompt-response') {
promptResponse = value;
} else if (command === 'set-append-response') {
appendResponse = value;
} else if (command === 'set-measure-response') {
measureResponse = value;
} else if (command === 'set-clone-response') {
cloneResponse = value;
}
process.parentPort.postMessage('ack');
});
// Blocks until `signal` aborts, notifies the parent with `messageType`, then
// throws so the pending handler call rejects. (events.once accepts
// EventTarget emitters, which AbortSignal is.)
async function waitForAbort (signal, messageType) {
await once(signal, 'abort');
process.parentPort.postMessage({ type: messageType });
throw new Error('Aborted');
}
// Fully instrumented model: every method reports its invocation (and options)
// to the parent and honors the configured response knob.
localAIHandler.setPromptAPIHandler(() => {
const ControllableLanguageModel = class extends LanguageModel {
static async create (options) {
process.parentPort.postMessage({ type: 'create-called', options });
if (createResponse === 'reject') {
return Promise.reject(new Error('Model is unavailable'));
} else if (createResponse === 'wait-for-abort') {
await waitForAbort(options.signal, 'create-aborted');
}
return new ControllableLanguageModel({
contextUsage: 0,
contextWindow: 0
});
}
static async availability (options) {
process.parentPort.postMessage({ type: 'availability-called', options });
if (availabilityState === 'reject') {
return Promise.reject(new Error('Model is unavailable'));
}
return availabilityState;
}
async prompt (input, options) {
process.parentPort.postMessage({ type: 'prompt-called', input, options });
if (promptResponse === 'reject') {
return Promise.reject(new Error('Model is unavailable'));
} else if (promptResponse === 'wait-for-abort') {
await waitForAbort(options.signal, 'prompt-aborted');
}
return promptResponse;
}
async append (input, options) {
process.parentPort.postMessage({ type: 'append-called', input, options });
if (appendResponse === 'reject') {
return Promise.reject(new Error('Append failed'));
} else if (appendResponse === 'wait-for-abort') {
await waitForAbort(options.signal, 'append-aborted');
}
}
async measureContextUsage (input, options) {
process.parentPort.postMessage({ type: 'measure-called', input, options });
if (measureResponse === 'reject') {
return Promise.reject(new Error('Measure failed'));
} else if (measureResponse === 'wait-for-abort') {
await waitForAbort(options.signal, 'measure-aborted');
}
return measureResponse;
}
// clone() carries the current context values into the new instance.
async clone (options) {
process.parentPort.postMessage({ type: 'clone-called', options });
if (cloneResponse === 'reject') {
return Promise.reject(new Error('Clone failed'));
} else if (cloneResponse === 'wait-for-abort') {
await waitForAbort(options.signal, 'clone-aborted');
} else if (cloneResponse === 'invalid') {
return 'not-a-language-model';
}
return new ControllableLanguageModel({
contextUsage: this.contextUsage,
contextWindow: this.contextWindow
});
}
destroy () {
process.parentPort.postMessage({ type: 'destroy-called' });
}
};
return ControllableLanguageModel;
});

View File

@@ -0,0 +1,5 @@
const { localAIHandler, LanguageModel } = require('electron/utility');

// Fixture: registers the base class as-is, exercising its default
// implementations for the whole Prompt API surface.
localAIHandler.setPromptAPIHandler(() => LanguageModel);

View File

@@ -0,0 +1,18 @@
const { localAIHandler, LanguageModel } = require('electron/utility');

let callCount = 0;

// The parent test can ask this process to unregister its handler at runtime;
// every control message is ack'd for the spec-side helper.
process.parentPort.on('message', (e) => {
  if (e.data.command === 'clear-handler') {
    localAIHandler.setPromptAPIHandler(null);
  }
  process.parentPort.postMessage('ack');
});

// Report every handler invocation — with its details object and a running
// call count — back to the parent test process.
localAIHandler.setPromptAPIHandler((details) => {
  callCount += 1;
  process.parentPort.postMessage({ type: 'handler-called', details, callCount });
  return LanguageModel;
});

View File

@@ -0,0 +1,3 @@
// Fixture: intentionally registers NO Prompt API handler; it only acks
// control messages so the shared spec helpers keep working.
process.parentPort.on('message', () => process.parentPort.postMessage('ack'));

View File

@@ -0,0 +1,5 @@
const { localAIHandler } = require('electron/utility');

// Fixture: misbehaving handler that returns a plain string where a
// LanguageModel subclass is required.
localAIHandler.setPromptAPIHandler(() => 'not-a-class');

View File

@@ -0,0 +1,5 @@
const { localAIHandler } = require('electron/utility');

// Fixture: misbehaving handler that returns a class which does not extend
// LanguageModel.
localAIHandler.setPromptAPIHandler(() => {
  class NotALanguageModel {}
  return NotALanguageModel;
});

View File

@@ -0,0 +1,5 @@
const { localAIHandler, LanguageModel } = require('electron/utility');

// Fixture: misbehaving handler that resolves the class asynchronously instead
// of returning it synchronously.
localAIHandler.setPromptAPIHandler(() => Promise.resolve(LanguageModel));

View File

@@ -0,0 +1,28 @@
const { localAIHandler, LanguageModel } = require('electron/utility');
const { ReadableStream } = require('node:stream/web');

// Fixture: streams the prompt response as two string chunks ('Hello ' then
// 'World') instead of returning it as one value.
localAIHandler.setPromptAPIHandler(() => {
  class StreamingLanguageModel extends LanguageModel {
    static async create () {
      return new StreamingLanguageModel({ contextUsage: 0, contextWindow: 0 });
    }

    async prompt () {
      this.contextUsage += 10;
      return new ReadableStream({
        async start (controller) {
          controller.enqueue('Hello ');
          controller.enqueue('World');
          controller.close();
        }
      });
    }
  }

  return StreamingLanguageModel;
});

View File

@@ -20,7 +20,8 @@ import {
session,
systemPreferences,
webContents,
TouchBar
TouchBar,
utilityProcess
} from 'electron/main';
import { clipboard, crashReporter, nativeImage, shell } from 'electron/common';
@@ -1266,6 +1267,9 @@ session.defaultSession.webRequest.onBeforeSendHeaders(filter, function (details:
callback({ cancel: false, requestHeaders: details.requestHeaders });
});
session.defaultSession.registerLocalAIHandler(utilityProcess.fork(path.join(__dirname, 'ai-handler.js')));
session.defaultSession.registerLocalAIHandler(null);
app.whenReady().then(function () {
const protocol = session.defaultSession.protocol;
protocol.registerFileProtocol('atom', function (request, callback) {

View File

@@ -1,6 +1,6 @@
/* eslint-disable */
import { net, systemPreferences } from 'electron/utility';
import { localAIHandler, net, systemPreferences, LanguageModel } from 'electron/utility';
process.parentPort.on('message', (e) => {
if (e.data === 'Hello from parent!') {
@@ -65,3 +65,25 @@ if (process.platform === 'darwin') {
const value2 = systemPreferences.getUserDefault('Foo', 'boolean');
console.log(value2);
}
// localAIHandler
// https://github.com/electron/electron/blob/main/docs/api/local-ai-handler.md
// Type smoke test: registering a handler that returns an anonymous
// LanguageModel subclass must type-check.
localAIHandler.setPromptAPIHandler((details) => {
return class MyLanguageModel extends LanguageModel {
// Captures the per-handler-call details object — presumably only to
// type-check the `details` parameter; it is not read here.
private details = details;
static async create() {
return new MyLanguageModel({
contextUsage: 0,
contextWindow: 0,
})
}
async prompt () {
return "Hello World"
}
}
})
// The registration API must also accept null to clear the handler.
localAIHandler.setPromptAPIHandler(null);

View File

@@ -19,6 +19,7 @@ declare namespace NodeJS {
isPDFViewerEnabled(): boolean;
isFakeLocationProviderEnabled(): boolean;
isPrintingEnabled(): boolean;
isPromptAPIEnabled(): boolean;
isExtensionsEnabled(): boolean;
isComponentBuild(): boolean;
}

View File

@@ -146,6 +146,7 @@ declare namespace Electron {
interface Session {
_setDisplayMediaRequestHandler: Electron.Session['setDisplayMediaRequestHandler'];
_registerLocalAIHandler(handler: ElectronInternal.UtilityProcessWrapper | null): void;
}
type CreateWindowFunction = (options: BrowserWindowConstructorOptions) => WebContents;