mirror of
https://github.com/electron/electron.git
synced 2026-04-10 03:01:51 -04:00
Compare commits
1 Commits
remove-dec
...
feat/promp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39aed69a33 |
11
BUILD.gn
11
BUILD.gn
@@ -611,6 +611,17 @@ source_set("electron_lib") {
|
||||
]
|
||||
}
|
||||
|
||||
if (enable_prompt_api) {
|
||||
sources += [
|
||||
"shell/browser/ai/proxying_ai_manager.cc",
|
||||
"shell/browser/ai/proxying_ai_manager.h",
|
||||
"shell/utility/ai/utility_ai_language_model.cc",
|
||||
"shell/utility/ai/utility_ai_language_model.h",
|
||||
"shell/utility/ai/utility_ai_manager.cc",
|
||||
"shell/utility/ai/utility_ai_manager.h",
|
||||
]
|
||||
}
|
||||
|
||||
if (is_mac) {
|
||||
# Disable C++ modules to resolve linking error when including MacOS SDK
|
||||
# headers from third_party/electron_node/deps/uv/include/uv/darwin.h
|
||||
|
||||
@@ -12,6 +12,7 @@ buildflag_header("buildflags") {
|
||||
"ENABLE_PDF_VIEWER=$enable_pdf_viewer",
|
||||
"ENABLE_ELECTRON_EXTENSIONS=$enable_electron_extensions",
|
||||
"ENABLE_BUILTIN_SPELLCHECKER=$enable_builtin_spellchecker",
|
||||
"ENABLE_PROMPT_API=$enable_prompt_api",
|
||||
"OVERRIDE_LOCATION_PROVIDER=$enable_fake_location_provider",
|
||||
]
|
||||
|
||||
|
||||
@@ -17,6 +17,9 @@ declare_args() {
|
||||
# Enable Spellchecker support
|
||||
enable_builtin_spellchecker = true
|
||||
|
||||
# Enable Prompt API support.
|
||||
enable_prompt_api = true
|
||||
|
||||
# The version of Electron.
|
||||
# Packagers and vendor builders should set this in gn args to avoid running
|
||||
# the script that reads git tag.
|
||||
|
||||
85
docs/api/language-model.md
Normal file
85
docs/api/language-model.md
Normal file
@@ -0,0 +1,85 @@
|
||||
## Class: LanguageModel
|
||||
|
||||
> Implement local AI language models
|
||||
|
||||
Process: [Utility](../glossary.md#utility-process)
|
||||
|
||||
### `new LanguageModel(initialState)`
|
||||
|
||||
* `initialState` Object
|
||||
* `contextUsage` number
|
||||
* `contextWindow` number
|
||||
|
||||
> [!NOTE]
|
||||
> Do not use this constructor directly outside of the class itself, as it will not be properly connected to the `localAIHandler`
|
||||
|
||||
### Static Methods
|
||||
|
||||
The `LanguageModel` class has the following static methods:
|
||||
|
||||
#### `LanguageModel.create(options)` _Experimental_
|
||||
|
||||
* `options` [LanguageModelCreateOptions](structures/language-model-create-options.md)
|
||||
|
||||
Returns `Promise<LanguageModel>`. Creates a new `LanguageModel` with the provided `options`.
|
||||
|
||||
#### `LanguageModel.availability([options])` _Experimental_
|
||||
|
||||
* `options` [LanguageModelCreateCoreOptions](structures/language-model-create-core-options.md) (optional)
|
||||
|
||||
Returns `Promise<string>`
|
||||
|
||||
Determines the availability of the language model and returns one of the following strings:
|
||||
|
||||
* `available`
|
||||
* `downloadable`
|
||||
* `downloading`
|
||||
* `unavailable`
|
||||
|
||||
### Instance Properties
|
||||
|
||||
The following properties are available on instances of `LanguageModel`:
|
||||
|
||||
#### `languageModel.contextUsage` _Experimental_
|
||||
|
||||
A `number` representing how many tokens are currently in the context window.
|
||||
|
||||
#### `languageModel.contextWindow` _Experimental_
|
||||
|
||||
A `number` representing the size of the context window, in tokens.
|
||||
|
||||
### Instance Methods
|
||||
|
||||
The following methods are available on instances of `LanguageModel`:
|
||||
|
||||
#### `languageModel.prompt(input, options)` _Experimental_
|
||||
|
||||
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
|
||||
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
|
||||
|
||||
Returns `Promise<string> | Promise<import('stream/web').ReadableStream<string>>`. Prompt the model for a response.
|
||||
|
||||
#### `languageModel.append(input, options)` _Experimental_
|
||||
|
||||
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
|
||||
* `options` [LanguageModelAppendOptions](structures/language-model-append-options.md)
|
||||
|
||||
Returns `Promise<undefined>`. Append a message without prompting for a response.
|
||||
|
||||
#### `languageModel.measureContextUsage(input, options)` _Experimental_
|
||||
|
||||
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
|
||||
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
|
||||
|
||||
Returns `Promise<number>`. Measure how many tokens the input would use.
|
||||
|
||||
#### `languageModel.clone(options)` _Experimental_
|
||||
|
||||
* `options` [LanguageModelCloneOptions](structures/language-model-clone-options.md)
|
||||
|
||||
Returns `Promise<LanguageModel>`. Clones the `LanguageModel` such that the
|
||||
context and initial prompt should be preserved.
|
||||
|
||||
#### `languageModel.destroy()` _Experimental_
|
||||
|
||||
Destroys the model, and any ongoing executions are aborted.
|
||||
26
docs/api/local-ai-handler.md
Normal file
26
docs/api/local-ai-handler.md
Normal file
@@ -0,0 +1,26 @@
|
||||
# localAIHandler
|
||||
|
||||
> Proxy built-in AI APIs to a local LLM implementation
|
||||
|
||||
Process: [Utility](../glossary.md#utility-process)
|
||||
|
||||
This module is intended to be used by a script registered to a session via
|
||||
[`ses.registerLocalAIHandler(handler)`](./session.md#sesregisterlocalaihandlerhandler-experimental)
|
||||
|
||||
## Methods
|
||||
|
||||
The `localAIHandler` module has the following methods:
|
||||
|
||||
#### `localAIHandler.setPromptAPIHandler(handler)` _Experimental_
|
||||
|
||||
* `handler` Function\<typeof [LanguageModel](language-model.md)\> | null
|
||||
* `details` Object
|
||||
* `webContentsId` Integer - The [unique id](web-contents.md#contentsid-readonly) of
|
||||
the [WebContents](web-contents.md) calling the Prompt API.
|
||||
* `securityOrigin` string - Origin of the page calling the Prompt API.
|
||||
|
||||
Sets the handler for new Prompt API binding requests from the renderer process. This happens
|
||||
once per pair of `webContentsId` and `securityOrigin`. Clearing the handler by calling
|
||||
`setPromptAPIHandler(null)` will prevent new Prompt API sessions from being started,
|
||||
but will not invalidate existing ones. If you want to invalidate existing Prompt API sessions,
|
||||
clear the local AI handler for the session using `ses.registerLocalAIHandler(null)`.
|
||||
@@ -1632,6 +1632,12 @@ This method clears more types of data and is more thorough than the
|
||||
|
||||
For more information, refer to Chromium's [`BrowsingDataRemover` interface][browsing-data-remover].
|
||||
|
||||
#### `ses.registerLocalAIHandler(handler)` _Experimental_
|
||||
|
||||
* `handler` [UtilityProcess](utility-process.md#class-utilityprocess) | null
|
||||
|
||||
Registers a local AI handler `UtilityProcess`. To clear the handler, call `registerLocalAIHandler(null)`, which will disconnect any existing Prompt API sessions and destroy any `LanguageModel` instances.
|
||||
|
||||
### Instance Properties
|
||||
|
||||
The following properties are available on instances of `Session`:
|
||||
|
||||
3
docs/api/structures/language-model-append-options.md
Normal file
3
docs/api/structures/language-model-append-options.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# LanguageModelAppendOptions Object
|
||||
|
||||
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
|
||||
3
docs/api/structures/language-model-clone-options.md
Normal file
3
docs/api/structures/language-model-clone-options.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# LanguageModelCloneOptions Object
|
||||
|
||||
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
|
||||
@@ -0,0 +1,4 @@
|
||||
# LanguageModelCreateCoreOptions Object
|
||||
|
||||
* `expectedInputs` [LanguageModelExpected[]](language-model-expected.md) (optional)
|
||||
* `expectedOutputs` [LanguageModelExpected[]](language-model-expected.md) (optional)
|
||||
4
docs/api/structures/language-model-create-options.md
Normal file
4
docs/api/structures/language-model-create-options.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# LanguageModelCreateOptions Object extends `LanguageModelCreateCoreOptions`
|
||||
|
||||
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
|
||||
* `initialPrompts` [LanguageModelMessage[]](language-model-message.md) (optional)
|
||||
7
docs/api/structures/language-model-expected.md
Normal file
7
docs/api/structures/language-model-expected.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# LanguageModelExpected Object
|
||||
|
||||
* `type` string - Can be one of the following values:
|
||||
* `text`
|
||||
* `image`
|
||||
* `audio`
|
||||
* `languages` string[] (optional)
|
||||
7
docs/api/structures/language-model-message-content.md
Normal file
7
docs/api/structures/language-model-message-content.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# LanguageModelMessageContent Object
|
||||
|
||||
* `type` string - Can be one of the following values:
|
||||
* `text`
|
||||
* `image`
|
||||
* `audio`
|
||||
* `value` ArrayBuffer | string
|
||||
8
docs/api/structures/language-model-message.md
Normal file
8
docs/api/structures/language-model-message.md
Normal file
@@ -0,0 +1,8 @@
|
||||
# LanguageModelMessage Object
|
||||
|
||||
* `role` string - Can be one of the following values:
|
||||
* `system`
|
||||
* `user`
|
||||
* `assistant`
|
||||
* `content` [LanguageModelMessageContent[]](language-model-message-content.md)
|
||||
* `prefix` boolean (optional)
|
||||
4
docs/api/structures/language-model-prompt-options.md
Normal file
4
docs/api/structures/language-model-prompt-options.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# LanguageModelPromptOptions Object
|
||||
|
||||
* `responseConstraint` Object (optional)
|
||||
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
|
||||
175
docs/tutorial/local-ai-handler.md
Normal file
175
docs/tutorial/local-ai-handler.md
Normal file
@@ -0,0 +1,175 @@
|
||||
---
|
||||
title: Local AI Handler
|
||||
description: Handle built-in AI APIs with a local LLM implementation
|
||||
slug: local-ai-handler
|
||||
hide_title: true
|
||||
---
|
||||
|
||||
# Local AI Handler
|
||||
|
||||
> **This API is experimental.** It may change or be removed in future Electron releases.
|
||||
|
||||
Electron supports [Prompt API](https://github.com/webmachinelearning/prompt-api)
|
||||
(`LanguageModel`) web API by letting you route calls to a local LLM running in a
|
||||
[utility process](../api/utility-process.md). Web content calls
|
||||
`LanguageModel.create()` and `LanguageModel.prompt()` like it would in any
|
||||
browser, while your Electron app decides which model handles the request.
|
||||
|
||||
## How it works
|
||||
|
||||
The local AI handler architecture involves three processes:
|
||||
|
||||
1. **Main process** — creates `UtilityProcess`, and then registers it to handle
|
||||
Prompt API calls for a given session via [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental).
|
||||
2. **Utility process** — runs a script that calls
|
||||
[`localAIHandler.setPromptAPIHandler()`](../api/local-ai-handler.md#localaihandlersetpromptapihandlerhandler-experimental)
|
||||
to supply a `LanguageModel` subclass.
|
||||
3. **Renderer process** — web content uses the standard `LanguageModel` API
|
||||
(e.g. `LanguageModel.create()`, `model.prompt()`).
|
||||
|
||||
When a renderer calls the Prompt API, Electron proxies the request through the
|
||||
main process to the registered utility process, which invokes your
|
||||
`LanguageModel` implementation and sends the result back directly to the renderer.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
The Prompt API Blink feature must be enabled on any `BrowserWindow` that will
|
||||
use it with the `AIPromptAPI` feature. To enable multi-modal inputs, add the
|
||||
`AIPromptAPIMultimodalInput` as well.
|
||||
|
||||
```js
|
||||
const win = new BrowserWindow({
|
||||
webPreferences: {
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## Quick start
|
||||
|
||||
### 1. Create the utility process script
|
||||
|
||||
The utility process script registers your `LanguageModel` subclass. The
|
||||
handler function receives a `details` object with information about the
|
||||
caller, and must return a class that extends `LanguageModel`.
|
||||
|
||||
```js title='ai-handler.js (Utility Process)'
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility')
|
||||
|
||||
localAIHandler.setPromptAPIHandler((details) => {
|
||||
// details.webContentsId — ID of the calling WebContents
|
||||
// details.securityOrigin — origin of the calling page
|
||||
|
||||
return class MyLanguageModel extends LanguageModel {
|
||||
static async create (options) {
|
||||
// options.signal - AbortSignal to cancel the creation of the model
|
||||
// options.initialPrompts - initial prompts to pass to the language model
|
||||
|
||||
return new MyLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 4096
|
||||
})
|
||||
}
|
||||
|
||||
static async availability () {
|
||||
// Return 'available', 'downloadable', 'downloading', or 'unavailable'
|
||||
return 'available'
|
||||
}
|
||||
|
||||
async prompt (input) {
|
||||
// input is a string or LanguageModelMessage[]
|
||||
// Return a string response from your model, or a ReadableStream
|
||||
// to return a streaming response.
|
||||
return 'This is a response from your local LLM!'
|
||||
}
|
||||
|
||||
async clone () {
|
||||
return new MyLanguageModel({
|
||||
contextUsage: this.contextUsage,
|
||||
contextWindow: this.contextWindow
|
||||
})
|
||||
}
|
||||
|
||||
destroy () {
|
||||
// Clean up model resources
|
||||
}
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 2. Register the handler in the main process
|
||||
|
||||
Fork the utility process and register it as the AI handler for a session:
|
||||
|
||||
```js title='main.js (Main Process)'
|
||||
const { app, BrowserWindow, utilityProcess } = require('electron')
|
||||
|
||||
const path = require('node:path')
|
||||
|
||||
app.whenReady().then(() => {
|
||||
// Fork the utility process running your AI handler script
|
||||
const aiHandler = utilityProcess.fork(path.join(__dirname, 'ai-handler.js'))
|
||||
|
||||
// Create a window with the Prompt API enabled
|
||||
const win = new BrowserWindow({
|
||||
webPreferences: {
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
})
|
||||
|
||||
// Connect the AI handler to this session
|
||||
win.webContents.session.registerLocalAIHandler(aiHandler)
|
||||
|
||||
win.loadFile('index.html')
|
||||
})
|
||||
```
|
||||
|
||||
### 3. Use the Prompt API in your renderer
|
||||
|
||||
Your web content can now use the standard `LanguageModel` API, which is a
|
||||
global available in the renderer:
|
||||
|
||||
```html title='index.html (Renderer Process)'
|
||||
<script>
|
||||
async function askAI () {
|
||||
const model = await LanguageModel.create()
|
||||
const response = await model.prompt('What is Electron?')
|
||||
document.getElementById('response').textContent = response
|
||||
}
|
||||
</script>
|
||||
|
||||
<button onclick="askAI()">Ask AI</button>
|
||||
<p id="response"></p>
|
||||
```
|
||||
|
||||
## Implementing a real model
|
||||
|
||||
The quick-start example returns a hardcoded string. A real implementation
|
||||
would integrate with a local model. See [`electron/llm`](https://github.com/electron/llm)
|
||||
for an example of using `node-llama-cpp` to wire up GGUF (GPT-Generated Unified Format) models.
|
||||
|
||||
## Clearing the handler
|
||||
|
||||
To disconnect the AI handler from a session, pass `null`:
|
||||
|
||||
```js @ts-type={win:Electron.BrowserWindow}
|
||||
win.webContents.session.registerLocalAIHandler(null)
|
||||
```
|
||||
|
||||
After clearing, any `LanguageModel.create()` calls from renderers using that
|
||||
session will fail.
|
||||
|
||||
## Security considerations
|
||||
|
||||
The `details` object passed to your handler includes `webContentsId` and
|
||||
`securityOrigin`. Use these to decide whether to handle a request, and
|
||||
when to reuse a model instance versus providing a fresh instance to
|
||||
provide proper isolation between origins.
|
||||
|
||||
## Further reading
|
||||
|
||||
- [`localAIHandler` API reference](../api/local-ai-handler.md)
|
||||
- [`LanguageModel` API reference](../api/language-model.md)
|
||||
- [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental)
|
||||
- [`utilityProcess.fork()`](../api/utility-process.md#utilityprocessforkmodulepath-args-options)
|
||||
- [`electron/llm`](https://github.com/electron/llm)
|
||||
@@ -30,6 +30,8 @@ auto_filenames = {
|
||||
"docs/api/ipc-main-service-worker.md",
|
||||
"docs/api/ipc-main.md",
|
||||
"docs/api/ipc-renderer.md",
|
||||
"docs/api/language-model.md",
|
||||
"docs/api/local-ai-handler.md",
|
||||
"docs/api/menu-item.md",
|
||||
"docs/api/menu.md",
|
||||
"docs/api/message-channel-main.md",
|
||||
@@ -108,6 +110,14 @@ auto_filenames = {
|
||||
"docs/api/structures/jump-list-item.md",
|
||||
"docs/api/structures/keyboard-event.md",
|
||||
"docs/api/structures/keyboard-input-event.md",
|
||||
"docs/api/structures/language-model-append-options.md",
|
||||
"docs/api/structures/language-model-clone-options.md",
|
||||
"docs/api/structures/language-model-create-core-options.md",
|
||||
"docs/api/structures/language-model-create-options.md",
|
||||
"docs/api/structures/language-model-expected.md",
|
||||
"docs/api/structures/language-model-message-content.md",
|
||||
"docs/api/structures/language-model-message.md",
|
||||
"docs/api/structures/language-model-prompt-options.md",
|
||||
"docs/api/structures/media-access-permission-request.md",
|
||||
"docs/api/structures/memory-info.md",
|
||||
"docs/api/structures/memory-usage-details.md",
|
||||
@@ -400,6 +410,8 @@ auto_filenames = {
|
||||
"lib/common/init.ts",
|
||||
"lib/common/webpack-globals-provider.ts",
|
||||
"lib/utility/api/exports/electron.ts",
|
||||
"lib/utility/api/language-model.ts",
|
||||
"lib/utility/api/local-ai-handler.ts",
|
||||
"lib/utility/api/module-list.ts",
|
||||
"lib/utility/api/net.ts",
|
||||
"lib/utility/init.ts",
|
||||
|
||||
@@ -748,6 +748,8 @@ filenames = {
|
||||
"shell/services/node/node_service.h",
|
||||
"shell/services/node/parent_port.cc",
|
||||
"shell/services/node/parent_port.h",
|
||||
"shell/utility/api/electron_api_local_ai_handler.cc",
|
||||
"shell/utility/api/electron_api_local_ai_handler.h",
|
||||
"shell/utility/electron_content_utility_client.cc",
|
||||
"shell/utility/electron_content_utility_client.h",
|
||||
]
|
||||
|
||||
@@ -2,7 +2,7 @@ import { fetchWithSession } from '@electron/internal/browser/api/net-fetch';
|
||||
import { addIpcDispatchListeners } from '@electron/internal/browser/ipc-dispatch';
|
||||
import * as deprecate from '@electron/internal/common/deprecate';
|
||||
|
||||
import { net } from 'electron/main';
|
||||
import { net, type UtilityProcess } from 'electron/main';
|
||||
|
||||
const { fromPartition, fromPath, Session } = process._linkedBinding('electron_browser_session');
|
||||
const { isDisplayMediaSystemPickerAvailable } = process._linkedBinding('electron_browser_desktop_capturer');
|
||||
@@ -111,6 +111,12 @@ Session.prototype.removeExtension = deprecate.moveAPI(
|
||||
'session.extensions.removeExtension'
|
||||
);
|
||||
|
||||
Session.prototype.registerLocalAIHandler = function (handler: UtilityProcess | null) {
|
||||
// We need to unwrap the userland `ForkUtilityProcess` object and get the underlying
|
||||
// `ElectronInternal.UtilityProcessWrapper` before we call the C++ function
|
||||
return this._registerLocalAIHandler(handler !== null ? (handler as any)._unwrapHandle() : null);
|
||||
};
|
||||
|
||||
export default {
|
||||
fromPartition,
|
||||
fromPath,
|
||||
|
||||
@@ -131,6 +131,10 @@ class ForkUtilityProcess extends EventEmitter implements Electron.UtilityProcess
|
||||
return this.#stderr;
|
||||
}
|
||||
|
||||
_unwrapHandle () {
|
||||
return this.#handle;
|
||||
}
|
||||
|
||||
postMessage (message: any, transfer?: MessagePortMain[]) {
|
||||
if (Array.isArray(transfer)) {
|
||||
transfer = transfer.map((o: any) => o instanceof MessagePortMain ? o._internalPort : o);
|
||||
|
||||
44
lib/utility/api/language-model.ts
Normal file
44
lib/utility/api/language-model.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
interface LanguageModelConstructorValues {
|
||||
contextUsage: number;
|
||||
contextWindow: number;
|
||||
}
|
||||
|
||||
export default class LanguageModel implements Electron.LanguageModel {
|
||||
contextUsage: number;
|
||||
contextWindow: number;
|
||||
|
||||
constructor (values: LanguageModelConstructorValues) {
|
||||
this.contextUsage = values.contextUsage;
|
||||
this.contextWindow = values.contextWindow;
|
||||
}
|
||||
|
||||
static async create (): Promise<LanguageModel> {
|
||||
return new LanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 0
|
||||
});
|
||||
}
|
||||
|
||||
static async availability () {
|
||||
return 'available';
|
||||
}
|
||||
|
||||
async prompt () {
|
||||
return '';
|
||||
}
|
||||
|
||||
async append (): Promise<undefined> {}
|
||||
|
||||
async measureContextUsage () {
|
||||
return 0;
|
||||
}
|
||||
|
||||
async clone () {
|
||||
return new LanguageModel({
|
||||
contextUsage: this.contextUsage,
|
||||
contextWindow: this.contextWindow
|
||||
});
|
||||
}
|
||||
|
||||
destroy () {}
|
||||
}
|
||||
3
lib/utility/api/local-ai-handler.ts
Normal file
3
lib/utility/api/local-ai-handler.ts
Normal file
@@ -0,0 +1,3 @@
|
||||
const binding = process._linkedBinding('electron_utility_local_ai_handler');
|
||||
|
||||
export const setPromptAPIHandler = binding.setPromptAPIHandler;
|
||||
@@ -1,5 +1,7 @@
|
||||
// Utility side modules, please sort alphabetically.
|
||||
export const utilityNodeModuleList: ElectronInternal.ModuleEntry[] = [
|
||||
{ name: 'localAIHandler', loader: () => require('./local-ai-handler') },
|
||||
{ name: 'LanguageModel', loader: () => require('./language-model') },
|
||||
{ name: 'net', loader: () => require('./net') },
|
||||
{ name: 'systemPreferences', loader: () => require('@electron/internal/browser/api/system-preferences') }
|
||||
];
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import LanguageModel from '@electron/internal/utility/api/language-model';
|
||||
import { ParentPort } from '@electron/internal/utility/parent-port';
|
||||
|
||||
import { EventEmitter } from 'events';
|
||||
import { ReadableStream } from 'stream/web';
|
||||
import { pathToFileURL } from 'url';
|
||||
|
||||
const v8Util = process._linkedBinding('electron_common_v8_util');
|
||||
@@ -10,6 +12,11 @@ const entryScript: string = v8Util.getHiddenValue(process, '_serviceStartupScrip
|
||||
// we need to restore it here.
|
||||
process.argv.splice(1, 1, entryScript);
|
||||
|
||||
// These are used by C++ to more easily identify these objects.
|
||||
v8Util.setHiddenValue(global, 'isReadableStream', (val: unknown) => val instanceof ReadableStream);
|
||||
v8Util.setHiddenValue(global, 'isLanguageModel', (val: unknown) => val instanceof LanguageModel);
|
||||
v8Util.setHiddenValue(global, 'isLanguageModelClass', (val: any) => Object.is(val, LanguageModel) || val?.prototype instanceof LanguageModel || false);
|
||||
|
||||
// Import common settings.
|
||||
require('@electron/internal/common/init');
|
||||
|
||||
|
||||
@@ -150,3 +150,4 @@ fix_use_fresh_lazynow_for_onendworkitemimpl_after_didruntask.patch
|
||||
fix_pulseaudio_stream_and_icon_names.patch
|
||||
fix_fire_menu_popup_start_for_dynamically_created_aria_menus.patch
|
||||
feat_allow_enabling_extensions_on_custom_protocols.patch
|
||||
reject_prompt_api_promises_on_mojo_connection_disconnect.patch
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: David Sanders <dsanders11@ucsbalum.com>
|
||||
Date: Wed, 1 Apr 2026 21:14:38 -0700
|
||||
Subject: Reject Prompt API promises on Mojo connection disconnect
|
||||
|
||||
Without these changes to reject promises when the Mojo connection
|
||||
disconnects, these promises will hang indefinitely if the Prompt
|
||||
API handler is killed or unregistered.
|
||||
|
||||
This will be upstreamed to Chromium.
|
||||
|
||||
Change-Id: I89a6a076ae35cbaf12a93c517223a524bab3dff0
|
||||
|
||||
diff --git a/third_party/blink/renderer/modules/ai/language_model.cc b/third_party/blink/renderer/modules/ai/language_model.cc
|
||||
index c176575f2dcc049e478d5388ae1934aa3bc59786..b8eda3dd8733b1e7e92f70a8c75678df4a0314c8 100644
|
||||
--- a/third_party/blink/renderer/modules/ai/language_model.cc
|
||||
+++ b/third_party/blink/renderer/modules/ai/language_model.cc
|
||||
@@ -190,6 +190,9 @@ class CloneLanguageModelClient
|
||||
client_remote;
|
||||
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
|
||||
language_model->GetTaskRunner());
|
||||
+ receiver_.set_disconnect_handler(
|
||||
+ BindOnce(&CloneLanguageModelClient::OnConnectionError,
|
||||
+ WrapWeakPersistent(this)));
|
||||
language_model_->GetAILanguageModelRemote()->Fork(std::move(client_remote));
|
||||
}
|
||||
~CloneLanguageModelClient() override = default;
|
||||
@@ -232,6 +235,11 @@ class CloneLanguageModelClient
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
+ void OnConnectionError() {
|
||||
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
|
||||
+ /*quota_error_info=*/nullptr);
|
||||
+ }
|
||||
+
|
||||
void ResetReceiver() override { receiver_.reset(); }
|
||||
|
||||
private:
|
||||
@@ -262,6 +270,8 @@ class AppendClient : public GarbageCollected<AppendClient>,
|
||||
mojo::PendingRemote<mojom::blink::ModelStreamingResponder> client_remote;
|
||||
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
|
||||
language_model->GetTaskRunner());
|
||||
+ receiver_.set_disconnect_handler(
|
||||
+ BindOnce(&AppendClient::OnConnectionError, WrapWeakPersistent(this)));
|
||||
language_model_->GetAILanguageModelRemote()->Append(
|
||||
std::move(prompts), std::move(client_remote));
|
||||
}
|
||||
@@ -317,6 +327,11 @@ class AppendClient : public GarbageCollected<AppendClient>,
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
+ void OnConnectionError() {
|
||||
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
+ /*quota_error_info=*/nullptr);
|
||||
+ }
|
||||
+
|
||||
void OnStreaming(const String& text) override {
|
||||
NOTREACHED() << "Append() should not invoke `OnStreaming()`";
|
||||
}
|
||||
@@ -761,6 +776,7 @@ void LanguageModel::ExecuteMeasureInputUsage(
|
||||
ScriptPromiseResolver<IDLDouble>* resolver,
|
||||
AbortSignal* signal,
|
||||
Vector<mojom::blink::AILanguageModelPromptPtr> prompts) {
|
||||
+ auto reject_fn = RejectOnDestruction(resolver, signal);
|
||||
language_model_remote_->MeasureInputUsage(
|
||||
std::move(prompts),
|
||||
BindOnce(
|
||||
@@ -783,7 +799,8 @@ void LanguageModel::ExecuteMeasureInputUsage(
|
||||
}
|
||||
resolver->Resolve(static_cast<double>(usage.value()));
|
||||
},
|
||||
- WrapPersistent(resolver), WrapPersistent(signal)));
|
||||
+ WrapPersistent(resolver), WrapPersistent(signal))
|
||||
+ .Then(std::move(reject_fn)));
|
||||
}
|
||||
|
||||
bool LanguageModel::ValidateInput(ScriptState* script_state,
|
||||
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.cc b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
|
||||
index ddc6fcda3ffbdc271bcdebfbd85aa711c063fee2..2b63e9e77dc1ed4a0ed9a687527efdc104974e7a 100644
|
||||
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.cc
|
||||
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
|
||||
@@ -509,6 +509,11 @@ void LanguageModelCreateClient::OnError(
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
+void LanguageModelCreateClient::OnConnectionError() {
|
||||
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
|
||||
+ /*quota_error_info=*/nullptr);
|
||||
+}
|
||||
+
|
||||
void LanguageModelCreateClient::ResetReceiver() {
|
||||
receiver_.reset();
|
||||
}
|
||||
@@ -524,6 +529,8 @@ void LanguageModelCreateClient::OnInitialPromptsResolved(
|
||||
mojo::PendingRemote<mojom::blink::AIManagerCreateLanguageModelClient>
|
||||
client_remote;
|
||||
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(), task_runner_);
|
||||
+ receiver_.set_disconnect_handler(
|
||||
+ BindOnce(&LanguageModelCreateClient::OnConnectionError, WrapWeakPersistent(this)));
|
||||
HeapMojoRemote<mojom::blink::AIManager>& ai_manager_remote =
|
||||
AIInterfaceProxy::GetAIManagerRemote(GetExecutionContext());
|
||||
|
||||
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.h b/third_party/blink/renderer/modules/ai/language_model_create_client.h
|
||||
index 9ed8dfbefeccf1627d56f5ccc315f06071a63e25..7c8e823608883171c115676db151b70eb2fd055d 100644
|
||||
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.h
|
||||
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.h
|
||||
@@ -49,6 +49,8 @@ class LanguageModelCreateClient
|
||||
// Process options and create, if the availability result is valid.
|
||||
void Create(mojom::blink::ModelAvailabilityCheckResult result);
|
||||
|
||||
+ void OnConnectionError();
|
||||
+
|
||||
// Continue creation after any initial prompts were processed or rejected.
|
||||
void OnInitialPromptsResolved(
|
||||
Vector<mojom::blink::AILanguageModelExpectedPtr> expected_inputs,
|
||||
diff --git a/third_party/blink/renderer/modules/ai/model_execution_responder.cc b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
|
||||
index 47b65b13adfab4b8f2597a23d38a386915643d1b..fa9b54e1069019a66b8dab6eb0efe5df8b34c11a 100644
|
||||
--- a/third_party/blink/renderer/modules/ai/model_execution_responder.cc
|
||||
+++ b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
|
||||
@@ -84,7 +84,10 @@ class Responder final : public GarbageCollected<Responder>,
|
||||
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
|
||||
BindNewPipeAndPassRemote(
|
||||
scoped_refptr<base::SequencedTaskRunner> task_runner) {
|
||||
- return receiver_.BindNewPipeAndPassRemote(task_runner);
|
||||
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
|
||||
+ receiver_.set_disconnect_handler(
|
||||
+ BindOnce(&Responder::OnConnectionError, WrapWeakPersistent(this)));
|
||||
+ return pending_remote;
|
||||
}
|
||||
|
||||
// `mojom::blink::ModelStreamingResponder` implementation.
|
||||
@@ -144,6 +147,11 @@ class Responder final : public GarbageCollected<Responder>,
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
+ void OnConnectionError() {
|
||||
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
+ /*quota_error_info=*/nullptr);
|
||||
+ }
|
||||
+
|
||||
void RecordResponseStatusMetrics(
|
||||
mojom::blink::ModelStreamingResponseStatus status) {
|
||||
base::UmaHistogramEnumeration(
|
||||
@@ -235,7 +243,10 @@ class StreamingResponder final
|
||||
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
|
||||
BindNewPipeAndPassRemote(
|
||||
scoped_refptr<base::SequencedTaskRunner> task_runner) {
|
||||
- return receiver_.BindNewPipeAndPassRemote(task_runner);
|
||||
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
|
||||
+ receiver_.set_disconnect_handler(BindOnce(
|
||||
+ &StreamingResponder::OnConnectionError, WrapWeakPersistent(this)));
|
||||
+ return pending_remote;
|
||||
}
|
||||
|
||||
ReadableStream* CreateReadableStream() {
|
||||
@@ -337,6 +348,11 @@ class StreamingResponder final
|
||||
Cleanup();
|
||||
}
|
||||
|
||||
+ void OnConnectionError() {
|
||||
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
+ /*quota_error_info=*/nullptr);
|
||||
+ }
|
||||
+
|
||||
void RecordResponseStatusMetrics(
|
||||
mojom::blink::ModelStreamingResponseStatus status) {
|
||||
base::UmaHistogramEnumeration(
|
||||
187
shell/browser/ai/proxying_ai_manager.cc
Normal file
187
shell/browser/ai/proxying_ai_manager.cc
Normal file
@@ -0,0 +1,187 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "shell/browser/ai/proxying_ai_manager.h"
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "base/functional/bind.h"
|
||||
#include "base/notimplemented.h"
|
||||
#include "content/public/browser/browser_context.h"
|
||||
#include "content/public/browser/render_frame_host.h"
|
||||
#include "content/public/browser/weak_document_ptr.h"
|
||||
#include "mojo/public/cpp/bindings/callback_helpers.h"
|
||||
#include "shell/browser/api/electron_api_session.h"
|
||||
#include "shell/browser/api/electron_api_web_contents.h"
|
||||
#include "shell/browser/session_preferences.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
|
||||
|
||||
namespace electron {
|
||||
|
||||
ProxyingAIManager::ProxyingAIManager(content::BrowserContext* browser_context,
|
||||
content::RenderFrameHost* rfh)
|
||||
: browser_context_(browser_context),
|
||||
rfh_(rfh ? rfh->GetWeakDocumentPtr() : content::WeakDocumentPtr()) {
|
||||
auto* session_prefs =
|
||||
SessionPreferences::FromBrowserContext(browser_context_);
|
||||
if (session_prefs) {
|
||||
ai_handler_changed_subscription_ =
|
||||
session_prefs->AddAIHandlerChangedCallback(
|
||||
base::BindRepeating(&ProxyingAIManager::OnAIHandlerChanged,
|
||||
weak_ptr_factory_.GetWeakPtr()));
|
||||
}
|
||||
}
|
||||
|
||||
ProxyingAIManager::~ProxyingAIManager() = default;
|
||||
|
||||
void ProxyingAIManager::OnAIHandlerChanged() {
|
||||
ai_manager_remote_.reset();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::AddReceiver(
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
|
||||
receivers_.Add(this, std::move(receiver));
|
||||
}
|
||||
|
||||
const mojo::Remote<blink::mojom::AIManager>&
|
||||
ProxyingAIManager::GetAIManagerRemote(const SessionPreferences& session_prefs) {
|
||||
if (!ai_manager_remote_.is_bound()) {
|
||||
auto* local_ai_handler = session_prefs.GetLocalAIHandler().get();
|
||||
|
||||
if (local_ai_handler) {
|
||||
auto* rfh = rfh_.AsRenderFrameHostIfValid();
|
||||
DCHECK(rfh);
|
||||
|
||||
auto* web_contents = electron::api::WebContents::From(
|
||||
content::WebContents::FromRenderFrameHost(rfh));
|
||||
std::optional<int32_t> web_contents_id;
|
||||
|
||||
if (web_contents) {
|
||||
web_contents_id = web_contents->ID();
|
||||
}
|
||||
|
||||
local_ai_handler->BindAIManager(
|
||||
web_contents_id, rfh->GetLastCommittedOrigin(),
|
||||
ai_manager_remote_.BindNewPipeAndPassReceiver());
|
||||
}
|
||||
}
|
||||
|
||||
return ai_manager_remote_;
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CanCreateLanguageModel(
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options,
|
||||
CanCreateLanguageModelCallback callback) {
|
||||
auto* session_prefs =
|
||||
SessionPreferences::FromBrowserContext(browser_context_);
|
||||
DCHECK(session_prefs);
|
||||
|
||||
// Default to unavailable. This ensures the callback is always invoked
|
||||
// even if there is no registered utility process handler, or the
|
||||
// process crashes.
|
||||
auto cb = mojo::WrapCallbackWithDefaultInvokeIfNotRun(
|
||||
std::move(callback),
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
|
||||
// Proxy the call through to the utility process
|
||||
auto& ai_manager = GetAIManagerRemote(*session_prefs);
|
||||
|
||||
if (ai_manager.is_bound()) {
|
||||
ai_manager->CanCreateLanguageModel(std::move(options), std::move(cb));
|
||||
}
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CreateLanguageModel(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) {
|
||||
auto* session_prefs =
|
||||
SessionPreferences::FromBrowserContext(browser_context_);
|
||||
DCHECK(session_prefs);
|
||||
|
||||
// Proxy the call through to the utility process
|
||||
auto& ai_manager = GetAIManagerRemote(*session_prefs);
|
||||
|
||||
if (!ai_manager.is_bound()) {
|
||||
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client_remote(std::move(client));
|
||||
client_remote->OnError(
|
||||
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
|
||||
/*quota_error_info=*/nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
ai_manager->CreateLanguageModel(std::move(client), std::move(options));
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CanCreateSummarizer(
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options,
|
||||
CanCreateSummarizerCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CreateSummarizer(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::GetLanguageModelParams(
|
||||
GetLanguageModelParamsCallback callback) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CanCreateWriter(
|
||||
blink::mojom::AIWriterCreateOptionsPtr options,
|
||||
CanCreateWriterCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CreateWriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
|
||||
blink::mojom::AIWriterCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CanCreateRewriter(
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options,
|
||||
CanCreateRewriterCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CreateRewriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CanCreateProofreader(
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options,
|
||||
CanCreateProofreaderCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void ProxyingAIManager::CreateProofreader(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void ProxyingAIManager::AddModelDownloadProgressObserver(
|
||||
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
|
||||
observer_remote) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
} // namespace electron
|
||||
103
shell/browser/ai/proxying_ai_manager.h
Normal file
103
shell/browser/ai/proxying_ai_manager.h
Normal file
@@ -0,0 +1,103 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#ifndef ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
|
||||
#define ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
|
||||
|
||||
#include "base/callback_list.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/supports_user_data.h"
|
||||
#include "content/public/browser/weak_document_ptr.h"
|
||||
#include "mojo/public/cpp/bindings/pending_receiver.h"
|
||||
#include "mojo/public/cpp/bindings/pending_remote.h"
|
||||
#include "mojo/public/cpp/bindings/receiver_set.h"
|
||||
#include "shell/browser/session_preferences.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
|
||||
|
||||
namespace content {
|
||||
class BrowserContext;
|
||||
class RenderFrameHost;
|
||||
} // namespace content
|
||||
|
||||
namespace electron {
|
||||
|
||||
// Owned by the host of the document / service worker via `SupportUserData`.
|
||||
// The browser-side implementation of `blink::mojom::AIManager`, which
|
||||
// proxies requests to a utility process if the session has a registered
|
||||
// handler.
|
||||
class ProxyingAIManager : public base::SupportsUserData::Data,
|
||||
public blink::mojom::AIManager {
|
||||
public:
|
||||
ProxyingAIManager(content::BrowserContext* browser_context,
|
||||
content::RenderFrameHost* rfh);
|
||||
ProxyingAIManager(const ProxyingAIManager&) = delete;
|
||||
ProxyingAIManager& operator=(const ProxyingAIManager&) = delete;
|
||||
|
||||
~ProxyingAIManager() override;
|
||||
|
||||
void AddReceiver(mojo::PendingReceiver<blink::mojom::AIManager> receiver);
|
||||
|
||||
private:
|
||||
// Lazily bind the AIManager remote so that the developer can
|
||||
// set the local AI handler after this class is already created
|
||||
[[nodiscard]] const mojo::Remote<blink::mojom::AIManager>& GetAIManagerRemote(
|
||||
const SessionPreferences& session_prefs);
|
||||
|
||||
void OnAIHandlerChanged();
|
||||
|
||||
// `blink::mojom::AIManager` implementation.
|
||||
void CanCreateLanguageModel(
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options,
|
||||
CanCreateLanguageModelCallback callback) override;
|
||||
void CreateLanguageModel(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) override;
|
||||
void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
|
||||
CanCreateSummarizerCallback callback) override;
|
||||
void CreateSummarizer(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options) override;
|
||||
void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
|
||||
void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
|
||||
CanCreateWriterCallback callback) override;
|
||||
void CreateWriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
|
||||
blink::mojom::AIWriterCreateOptionsPtr options) override;
|
||||
void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
|
||||
CanCreateRewriterCallback callback) override;
|
||||
void CreateRewriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options) override;
|
||||
void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
|
||||
CanCreateProofreaderCallback callback) override;
|
||||
void CreateProofreader(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
|
||||
client,
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options) override;
|
||||
void AddModelDownloadProgressObserver(
|
||||
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
|
||||
observer_remote) override;
|
||||
|
||||
mojo::ReceiverSet<blink::mojom::AIManager> receivers_;
|
||||
|
||||
raw_ptr<content::BrowserContext> browser_context_;
|
||||
|
||||
content::WeakDocumentPtr rfh_;
|
||||
|
||||
mojo::Remote<blink::mojom::AIManager> ai_manager_remote_;
|
||||
|
||||
base::CallbackListSubscription ai_handler_changed_subscription_;
|
||||
|
||||
base::WeakPtrFactory<ProxyingAIManager> weak_ptr_factory_{this};
|
||||
};
|
||||
|
||||
} // namespace electron
|
||||
|
||||
#endif // ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
|
||||
@@ -1555,6 +1555,26 @@ v8::Local<v8::Value> Session::ClearData(gin::Arguments* const args) {
|
||||
return promise_handle;
|
||||
}
|
||||
|
||||
void Session::RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
|
||||
v8::Local<v8::Value> val) {
|
||||
auto* isolate = JavascriptEnvironment::GetIsolate();
|
||||
gin_helper::Handle<UtilityProcessWrapper> handler;
|
||||
|
||||
if (!(val->IsNull() || gin::ConvertFromV8(isolate, val, &handler))) {
|
||||
thrower.ThrowTypeError("Must pass null or UtilityProcess");
|
||||
return;
|
||||
}
|
||||
|
||||
auto* prefs = SessionPreferences::FromBrowserContext(browser_context());
|
||||
DCHECK(prefs);
|
||||
|
||||
if (!handler.IsEmpty()) {
|
||||
prefs->SetLocalAIHandler(handler->GetWeakPtr());
|
||||
} else {
|
||||
prefs->SetLocalAIHandler(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
|
||||
base::Value Session::GetSpellCheckerLanguages() {
|
||||
return browser_context_->prefs()
|
||||
@@ -1841,6 +1861,7 @@ void Session::FillObjectTemplate(v8::Isolate* isolate,
|
||||
.SetMethod("setCodeCachePath", &Session::SetCodeCachePath)
|
||||
.SetMethod("clearCodeCaches", &Session::ClearCodeCaches)
|
||||
.SetMethod("clearData", &Session::ClearData)
|
||||
.SetMethod("_registerLocalAIHandler", &Session::RegisterLocalAIHandler)
|
||||
.SetProperty("cookies", &Session::Cookies)
|
||||
.SetProperty("extensions", &Session::Extensions)
|
||||
.SetProperty("netLog", &Session::NetLog)
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "gin/wrappable.h"
|
||||
#include "services/network/public/mojom/host_resolver.mojom-forward.h"
|
||||
#include "services/network/public/mojom/ssl_config.mojom-forward.h"
|
||||
#include "shell/browser/api/electron_api_utility_process.h"
|
||||
#include "shell/browser/api/ipc_dispatcher.h"
|
||||
#include "shell/browser/event_emitter_mixin.h"
|
||||
#include "shell/browser/net/resolve_proxy_helper.h"
|
||||
@@ -178,6 +179,8 @@ class Session final : public gin::Wrappable<Session>,
|
||||
void SetCodeCachePath(gin::Arguments* args);
|
||||
v8::Local<v8::Promise> ClearCodeCaches(const gin_helper::Dictionary& options);
|
||||
v8::Local<v8::Value> ClearData(gin::Arguments* args);
|
||||
void RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
|
||||
v8::Local<v8::Value> val);
|
||||
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
|
||||
base::Value GetSpellCheckerLanguages();
|
||||
void SetSpellCheckerLanguages(gin_helper::ErrorThrower thrower,
|
||||
|
||||
@@ -47,6 +47,10 @@
|
||||
#include "base/win/windows_types.h"
|
||||
#endif
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
namespace electron {
|
||||
|
||||
namespace {
|
||||
@@ -454,6 +458,19 @@ UtilityProcessWrapper::CreateURLLoaderFactoryParams() {
|
||||
return params;
|
||||
}
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
void UtilityProcessWrapper::BindAIManager(
|
||||
std::optional<int32_t> web_contents_id,
|
||||
const url::Origin& security_origin,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
|
||||
auto params = node::mojom::BindAIManagerParams::New();
|
||||
params->web_contents_id = web_contents_id;
|
||||
params->security_origin = security_origin;
|
||||
|
||||
node_service_remote_->BindAIManager(std::move(params), std::move(ai_manager));
|
||||
}
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
// static
|
||||
raw_ptr<UtilityProcessWrapper> UtilityProcessWrapper::FromProcessId(
|
||||
base::ProcessId pid) {
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/process/process_handle.h"
|
||||
#include "content/public/browser/service_process_host.h"
|
||||
#include "electron/buildflags/buildflags.h"
|
||||
#include "mojo/public/cpp/bindings/message.h"
|
||||
#include "mojo/public/cpp/bindings/remote.h"
|
||||
#include "shell/browser/event_emitter_mixin.h"
|
||||
@@ -24,6 +25,10 @@
|
||||
#include "shell/services/node/public/mojom/node_service.mojom.h"
|
||||
#include "v8/include/v8-forward.h"
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
namespace gin {
|
||||
class Arguments;
|
||||
} // namespace gin
|
||||
@@ -58,6 +63,16 @@ class UtilityProcessWrapper final
|
||||
static gin_helper::Handle<UtilityProcessWrapper> Create(gin::Arguments* args);
|
||||
static raw_ptr<UtilityProcessWrapper> FromProcessId(base::ProcessId pid);
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
void BindAIManager(std::optional<int32_t> web_contents_id,
|
||||
const url::Origin& security_origin,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager);
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
base::WeakPtr<UtilityProcessWrapper> GetWeakPtr() {
|
||||
return weak_factory_.GetWeakPtr();
|
||||
}
|
||||
|
||||
void Shutdown(uint32_t exit_code);
|
||||
|
||||
// gin_helper::Wrappable
|
||||
|
||||
@@ -230,12 +230,20 @@
|
||||
#include "ui/webui/resources/cr_components/help_bubble/help_bubble.mojom.h" // nogncheck
|
||||
#endif
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
#include "shell/browser/ai/proxying_ai_manager.h"
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
using content::BrowserThread;
|
||||
|
||||
namespace electron {
|
||||
|
||||
namespace {
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
const char kAIManagerUserDataKey[] = "ai_manager";
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
ElectronBrowserClient* g_browser_client = nullptr;
|
||||
|
||||
base::NoDestructor<std::string> g_io_thread_application_locale;
|
||||
@@ -1580,6 +1588,26 @@ void ElectronBrowserClient::
|
||||
#endif
|
||||
}
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
// Refs
|
||||
// https://source.chromium.org/chromium/chromium/src/+/main:chrome/browser/chrome_content_browser_client.cc;l=8724-8737;drc=74754be9d4550a487df006a51a33318245d37301
|
||||
void ElectronBrowserClient::BindAIManager(
|
||||
content::BrowserContext* browser_context,
|
||||
base::SupportsUserData* context_user_data,
|
||||
content::RenderFrameHost* rfh,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
|
||||
if (!context_user_data->GetUserData(kAIManagerUserDataKey)) {
|
||||
context_user_data->SetUserData(
|
||||
kAIManagerUserDataKey,
|
||||
std::make_unique<ProxyingAIManager>(browser_context, rfh));
|
||||
}
|
||||
|
||||
ProxyingAIManager* ai_manager = static_cast<ProxyingAIManager*>(
|
||||
context_user_data->GetUserData(kAIManagerUserDataKey));
|
||||
ai_manager->AddReceiver(std::move(receiver));
|
||||
}
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
std::string ElectronBrowserClient::GetApplicationLocale() {
|
||||
return BrowserThread::CurrentlyOn(BrowserThread::IO)
|
||||
? *g_io_thread_application_locale
|
||||
|
||||
@@ -274,6 +274,14 @@ class ElectronBrowserClient : public content::ContentBrowserClient,
|
||||
const content::ServiceWorkerVersionBaseInfo& service_worker_version_info,
|
||||
blink::AssociatedInterfaceRegistry& associated_registry) override;
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
void BindAIManager(
|
||||
content::BrowserContext* browser_context,
|
||||
base::SupportsUserData* context_user_data,
|
||||
content::RenderFrameHost* rfh,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> receiver) override;
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
bool HandleExternalProtocol(
|
||||
const GURL& url,
|
||||
content::WebContents::Getter web_contents_getter,
|
||||
|
||||
@@ -30,6 +30,11 @@ SessionPreferences* SessionPreferences::FromBrowserContext(
|
||||
return static_cast<SessionPreferences*>(context->GetUserData(&kLocatorKey));
|
||||
}
|
||||
|
||||
base::CallbackListSubscription SessionPreferences::AddAIHandlerChangedCallback(
|
||||
base::RepeatingClosure callback) {
|
||||
return ai_handler_changed_callbacks_.Add(std::move(callback));
|
||||
}
|
||||
|
||||
bool SessionPreferences::HasServiceWorkerPreloadScript() {
|
||||
const auto& preloads = preload_scripts();
|
||||
auto it = std::find_if(
|
||||
|
||||
@@ -7,8 +7,11 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "base/callback_list.h"
|
||||
#include "base/files/file_path.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "base/supports_user_data.h"
|
||||
#include "shell/browser/api/electron_api_utility_process.h"
|
||||
#include "shell/browser/preload_script.h"
|
||||
|
||||
namespace content {
|
||||
@@ -17,6 +20,10 @@ class BrowserContext;
|
||||
|
||||
namespace electron {
|
||||
|
||||
namespace api {
|
||||
class UtilityProcessWrapper;
|
||||
}
|
||||
|
||||
class SessionPreferences : public base::SupportsUserData::Data {
|
||||
public:
|
||||
static SessionPreferences* FromBrowserContext(
|
||||
@@ -30,6 +37,18 @@ class SessionPreferences : public base::SupportsUserData::Data {
|
||||
|
||||
bool HasServiceWorkerPreloadScript();
|
||||
|
||||
const base::WeakPtr<api::UtilityProcessWrapper>& GetLocalAIHandler() const {
|
||||
return local_ai_handler_;
|
||||
}
|
||||
|
||||
void SetLocalAIHandler(base::WeakPtr<api::UtilityProcessWrapper> handler) {
|
||||
local_ai_handler_ = handler;
|
||||
ai_handler_changed_callbacks_.Notify();
|
||||
}
|
||||
|
||||
base::CallbackListSubscription AddAIHandlerChangedCallback(
|
||||
base::RepeatingClosure callback);
|
||||
|
||||
private:
|
||||
SessionPreferences();
|
||||
|
||||
@@ -37,6 +56,8 @@ class SessionPreferences : public base::SupportsUserData::Data {
|
||||
static int kLocatorKey;
|
||||
|
||||
std::vector<PreloadScript> preload_scripts_;
|
||||
base::WeakPtr<api::UtilityProcessWrapper> local_ai_handler_;
|
||||
base::RepeatingClosureList ai_handler_changed_callbacks_;
|
||||
};
|
||||
|
||||
} // namespace electron
|
||||
|
||||
@@ -25,6 +25,10 @@ bool IsPrintingEnabled() {
|
||||
return BUILDFLAG(ENABLE_PRINTING);
|
||||
}
|
||||
|
||||
bool IsPromptAPIEnabled() {
|
||||
return BUILDFLAG(ENABLE_PROMPT_API);
|
||||
}
|
||||
|
||||
bool IsExtensionsEnabled() {
|
||||
return BUILDFLAG(ENABLE_ELECTRON_EXTENSIONS);
|
||||
}
|
||||
@@ -48,6 +52,7 @@ void Initialize(v8::Local<v8::Object> exports,
|
||||
dict.SetMethod("isFakeLocationProviderEnabled",
|
||||
&IsFakeLocationProviderEnabled);
|
||||
dict.SetMethod("isPrintingEnabled", &IsPrintingEnabled);
|
||||
dict.SetMethod("isPromptAPIEnabled", &IsPromptAPIEnabled);
|
||||
dict.SetMethod("isComponentBuild", &IsComponentBuild);
|
||||
dict.SetMethod("isExtensionsEnabled", &IsExtensionsEnabled);
|
||||
}
|
||||
|
||||
@@ -56,6 +56,16 @@ v8::Local<v8::Value> CustomEmit(v8::Isolate* isolate,
|
||||
converted_args));
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
|
||||
v8::Local<v8::Object> object,
|
||||
const char* method_name,
|
||||
Args&&... args) {
|
||||
v8::EscapableHandleScope scope(isolate);
|
||||
return scope.Escape(
|
||||
CustomEmit(isolate, object, method_name, std::forward<Args>(args)...));
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
|
||||
gin_helper::DeprecatedWrappable<T>* object,
|
||||
|
||||
@@ -116,6 +116,7 @@
|
||||
V(electron_browser_event_emitter) \
|
||||
V(electron_browser_system_preferences) \
|
||||
V(electron_common_net) \
|
||||
V(electron_utility_local_ai_handler) \
|
||||
V(electron_utility_parent_port)
|
||||
|
||||
#define ELECTRON_TESTING_BINDINGS(V) V(electron_common_testing)
|
||||
|
||||
@@ -151,6 +151,22 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
|
||||
return env;
|
||||
}
|
||||
|
||||
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate) {
|
||||
auto context = isolate->GetCurrentContext();
|
||||
auto global_object = context->Global();
|
||||
|
||||
auto value =
|
||||
global_object->Get(context, gin::StringToV8(isolate, "AbortController"))
|
||||
.ToLocalChecked();
|
||||
DCHECK(!value.IsEmpty() && value->IsObject());
|
||||
|
||||
DCHECK(value->IsFunction());
|
||||
auto constructor = value.As<v8::Function>();
|
||||
auto instance =
|
||||
constructor->NewInstance(context, 0, nullptr).ToLocalChecked();
|
||||
return instance;
|
||||
}
|
||||
|
||||
ExplicitMicrotasksScope::ExplicitMicrotasksScope(v8::MicrotaskQueue* queue)
|
||||
: microtask_queue_(queue), original_policy_(queue->microtasks_policy()) {
|
||||
// In browser-like processes, some nested run loops (macOS usually) may
|
||||
|
||||
@@ -66,6 +66,8 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
|
||||
node::EnvironmentFlags::Flags env_flags,
|
||||
std::string_view process_type = "");
|
||||
|
||||
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate);
|
||||
|
||||
// A scope that temporarily changes the microtask policy to explicit. Use this
|
||||
// anywhere that can trigger Node.js or uv_run().
|
||||
//
|
||||
|
||||
@@ -11,8 +11,10 @@
|
||||
#include "base/no_destructor.h"
|
||||
#include "base/process/process.h"
|
||||
#include "base/strings/utf_string_conversions.h"
|
||||
#include "electron/buildflags/buildflags.h"
|
||||
#include "electron/fuses.h"
|
||||
#include "electron/mas.h"
|
||||
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
|
||||
#include "net/base/network_change_notifier.h"
|
||||
#include "services/network/public/cpp/wrapper_shared_url_loader_factory.h"
|
||||
#include "services/network/public/mojom/host_resolver.mojom.h"
|
||||
@@ -30,6 +32,12 @@
|
||||
#include "shell/common/crash_keys.h"
|
||||
#endif
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
#include "shell/utility/ai/utility_ai_manager.h"
|
||||
#include "url/gurl.h"
|
||||
#include "url/origin.h"
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
namespace electron {
|
||||
|
||||
mojo::Remote<node::mojom::NodeServiceClient>& GetRemote() {
|
||||
@@ -215,4 +223,15 @@ void NodeService::UpdateURLLoaderFactory(
|
||||
params->use_network_observer_from_url_loader_factory);
|
||||
}
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
void NodeService::BindAIManager(
|
||||
node::mojom::BindAIManagerParamsPtr params,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
|
||||
mojo::MakeSelfOwnedReceiver(
|
||||
std::make_unique<UtilityAIManager>(params->web_contents_id,
|
||||
params->security_origin),
|
||||
std::move(ai_manager));
|
||||
}
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
} // namespace electron
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "electron/buildflags/buildflags.h"
|
||||
#include "mojo/public/cpp/bindings/pending_receiver.h"
|
||||
#include "mojo/public/cpp/bindings/pending_remote.h"
|
||||
#include "mojo/public/cpp/bindings/receiver.h"
|
||||
@@ -68,6 +69,12 @@ class NodeService : public node::mojom::NodeService {
|
||||
void UpdateURLLoaderFactory(
|
||||
node::mojom::URLLoaderFactoryParamsPtr params) override;
|
||||
|
||||
#if BUILDFLAG(ENABLE_PROMPT_API)
|
||||
void BindAIManager(
|
||||
node::mojom::BindAIManagerParamsPtr params,
|
||||
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) override;
|
||||
#endif // BUILDFLAG(ENABLE_PROMPT_API)
|
||||
|
||||
private:
|
||||
// This needs to be initialized first so that it can be destroyed last
|
||||
// after the node::Environment is destroyed. This ensures that if
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Use of this source code is governed by the MIT license that can be
|
||||
# found in the LICENSE file.
|
||||
|
||||
import("//electron/buildflags/buildflags.gni")
|
||||
import("//mojo/public/tools/bindings/mojom.gni")
|
||||
|
||||
mojom("mojom") {
|
||||
@@ -11,4 +12,8 @@ mojom("mojom") {
|
||||
"//sandbox/policy/mojom",
|
||||
"//third_party/blink/public/mojom:mojom_core",
|
||||
]
|
||||
|
||||
if (enable_prompt_api) {
|
||||
enabled_features = [ "enable_prompt_api" ]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,9 @@ import "mojo/public/mojom/base/file_path.mojom";
|
||||
import "sandbox/policy/mojom/sandbox.mojom";
|
||||
import "services/network/public/mojom/host_resolver.mojom";
|
||||
import "services/network/public/mojom/url_loader_factory.mojom";
|
||||
import "third_party/blink/public/mojom/ai/ai_manager.mojom";
|
||||
import "third_party/blink/public/mojom/messaging/message_port_descriptor.mojom";
|
||||
import "url/mojom/origin.mojom";
|
||||
|
||||
struct URLLoaderFactoryParams {
|
||||
pending_remote<network.mojom.URLLoaderFactory> url_loader_factory;
|
||||
@@ -24,6 +26,11 @@ struct NodeServiceParams {
|
||||
URLLoaderFactoryParams url_loader_factory_params;
|
||||
};
|
||||
|
||||
struct BindAIManagerParams {
|
||||
int32? web_contents_id;
|
||||
url.mojom.Origin security_origin;
|
||||
};
|
||||
|
||||
interface NodeServiceClient {
|
||||
OnV8FatalError(string location, string report);
|
||||
};
|
||||
@@ -34,4 +41,8 @@ interface NodeService {
|
||||
pending_remote<NodeServiceClient> client_remote);
|
||||
|
||||
UpdateURLLoaderFactory(URLLoaderFactoryParams params);
|
||||
|
||||
[EnableIf=enable_prompt_api]
|
||||
BindAIManager(BindAIManagerParams params,
|
||||
pending_receiver<blink.mojom.AIManager> ai_manager);
|
||||
};
|
||||
|
||||
726
shell/utility/ai/utility_ai_language_model.cc
Normal file
726
shell/utility/ai/utility_ai_language_model.cc
Normal file
@@ -0,0 +1,726 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "shell/utility/ai/utility_ai_language_model.h"
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "base/no_destructor.h"
|
||||
#include "base/notimplemented.h"
|
||||
#include "shell/browser/javascript_environment.h"
|
||||
#include "shell/common/gin_converters/callback_converter.h"
|
||||
#include "shell/common/gin_converters/std_converter.h"
|
||||
#include "shell/common/gin_helper/dictionary.h"
|
||||
#include "shell/common/gin_helper/event_emitter_caller.h"
|
||||
#include "shell/common/node_includes.h"
|
||||
#include "shell/common/node_util.h"
|
||||
#include "shell/common/v8_util.h"
|
||||
#include "shell/utility/ai/utility_ai_manager.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
|
||||
|
||||
namespace gin {
|
||||
|
||||
template <>
|
||||
struct Converter<on_device_model::mojom::ResponseConstraintPtr> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const on_device_model::mojom::ResponseConstraintPtr& val) {
|
||||
if (val.is_null())
|
||||
return v8::Undefined(isolate);
|
||||
|
||||
if (val->is_json_schema()) {
|
||||
return v8::JSON::Parse(isolate->GetCurrentContext(),
|
||||
StringToV8(isolate, val->get_json_schema()))
|
||||
.ToLocalChecked();
|
||||
} else if (val->is_regex()) {
|
||||
return v8::RegExp::New(isolate->GetCurrentContext(),
|
||||
StringToV8(isolate, val->get_regex()),
|
||||
v8::RegExp::kNone)
|
||||
.ToLocalChecked();
|
||||
}
|
||||
|
||||
return v8::Undefined(isolate);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelPromptRole> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
blink::mojom::AILanguageModelPromptRole value) {
|
||||
switch (value) {
|
||||
case blink::mojom::AILanguageModelPromptRole::kSystem:
|
||||
return StringToV8(isolate, "system");
|
||||
case blink::mojom::AILanguageModelPromptRole::kUser:
|
||||
return StringToV8(isolate, "user");
|
||||
case blink::mojom::AILanguageModelPromptRole::kAssistant:
|
||||
return StringToV8(isolate, "assistant");
|
||||
default:
|
||||
return StringToV8(isolate, "unknown");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelPromptContentPtr> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageModelPromptContentPtr& val) {
|
||||
if (val.is_null())
|
||||
return v8::Undefined(isolate);
|
||||
|
||||
auto dict = gin::Dictionary::CreateEmpty(isolate);
|
||||
|
||||
if (val->is_text()) {
|
||||
dict.Set("type", "text");
|
||||
dict.Set("value", val->get_text());
|
||||
} else if (val->is_bitmap()) {
|
||||
// Convert the bitmap to an ArrayBuffer
|
||||
// TODO - Are we going to make any guarantees about the shape of the image
|
||||
// data?
|
||||
SkBitmap& bitmap = val->get_bitmap();
|
||||
|
||||
const auto dst_info = SkImageInfo::MakeN32Premul(bitmap.dimensions());
|
||||
const size_t dst_n_bytes = dst_info.computeMinByteSize();
|
||||
auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
|
||||
|
||||
if (!bitmap.readPixels(dst_info, dst_buf->Data(), dst_info.minRowBytes(),
|
||||
0, 0)) {
|
||||
auto err = v8::Exception::TypeError(
|
||||
gin::StringToV8(isolate, "Invalid bitmap content in prompt"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
}
|
||||
|
||||
dict.Set("type", "image");
|
||||
dict.Set("value", dst_buf);
|
||||
} else if (val->is_audio()) {
|
||||
// Convert the audio data to an ArrayBuffer
|
||||
// TODO - Are we going to make any guarantees about the shape of the audio
|
||||
// data?
|
||||
on_device_model::mojom::AudioDataPtr& audio_data = val->get_audio();
|
||||
std::vector<float>& raw_data = audio_data->data;
|
||||
|
||||
const size_t dst_n_bytes =
|
||||
sizeof(std::remove_reference_t<decltype(raw_data)>::value_type) *
|
||||
raw_data.size();
|
||||
auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
|
||||
|
||||
UNSAFE_BUFFERS(
|
||||
std::ranges::copy(raw_data, static_cast<char*>(dst_buf->Data())));
|
||||
|
||||
dict.Set("type", "audio");
|
||||
dict.Set("value", dst_buf);
|
||||
}
|
||||
|
||||
return ConvertToV8(isolate, dict);
|
||||
}
|
||||
};
|
||||
|
||||
v8::Local<v8::Value> Converter<blink::mojom::AILanguageModelPromptPtr>::ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageModelPromptPtr& val) {
|
||||
if (val.is_null())
|
||||
return v8::Undefined(isolate);
|
||||
|
||||
auto dict = gin::Dictionary::CreateEmpty(isolate);
|
||||
|
||||
dict.Set("role", val->role);
|
||||
dict.Set("content", val->content);
|
||||
dict.Set("prefix", val->is_prefix);
|
||||
|
||||
return ConvertToV8(isolate, dict);
|
||||
}
|
||||
|
||||
} // namespace gin
|
||||
|
||||
namespace electron {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr std::string_view kIsReadableStreamKey = "isReadableStream";
|
||||
constexpr std::string_view kIsLanguageModelKey = "isLanguageModel";
|
||||
constexpr std::string_view kIsLanguageModelClassKey = "isLanguageModelClass";
|
||||
|
||||
v8::Local<v8::Function> GetPrivateBoolean(v8::Isolate* const isolate,
|
||||
const v8::Local<v8::Context>& context,
|
||||
std::string_view key) {
|
||||
auto binding_key = gin::StringToV8(isolate, key);
|
||||
auto private_binding_key = v8::Private::ForApi(isolate, binding_key);
|
||||
auto global_object = context->Global();
|
||||
auto value =
|
||||
global_object->GetPrivate(context, private_binding_key).ToLocalChecked();
|
||||
if (value.IsEmpty() || !value->IsFunction()) {
|
||||
LOG(FATAL) << "Attempted to get the '" << key
|
||||
<< "' value but it was missing";
|
||||
}
|
||||
return value.As<v8::Function>();
|
||||
}
|
||||
|
||||
bool IsReadableStream(v8::Isolate* isolate, v8::Local<v8::Value> val) {
|
||||
static base::NoDestructor<v8::Global<v8::Function>> is_readable_stream;
|
||||
|
||||
auto context = isolate->GetCurrentContext();
|
||||
|
||||
if (is_readable_stream.get()->IsEmpty()) {
|
||||
is_readable_stream->Reset(
|
||||
isolate, GetPrivateBoolean(isolate, context, kIsReadableStreamKey));
|
||||
}
|
||||
|
||||
v8::Local<v8::Value> args[] = {val};
|
||||
v8::Local<v8::Value> result =
|
||||
is_readable_stream->Get(isolate)
|
||||
->Call(context, v8::Null(isolate), std::size(args), args)
|
||||
.ToLocalChecked();
|
||||
|
||||
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
|
||||
}
|
||||
|
||||
uint64_t GetContextUsage(v8::Isolate* isolate,
|
||||
v8::Local<v8::Object> language_model) {
|
||||
auto context = isolate->GetCurrentContext();
|
||||
v8::Local<v8::Value> val =
|
||||
language_model->Get(context, gin::StringToV8(isolate, "contextUsage"))
|
||||
.ToLocalChecked();
|
||||
uint64_t token_count = 0;
|
||||
if (val->IsNumber()) {
|
||||
gin::ConvertFromV8(isolate, val, &token_count);
|
||||
}
|
||||
return token_count;
|
||||
}
|
||||
|
||||
// Owns itself. Will live as long as there's more data to process
|
||||
// and the Mojo remote is still connected.
|
||||
class PromptResponder {
|
||||
public:
|
||||
PromptResponder(v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> value,
|
||||
v8::Local<v8::Object> abort_controller,
|
||||
v8::Local<v8::Object> language_model,
|
||||
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
|
||||
pending_responder,
|
||||
UtilityAILanguageModel* model) {
|
||||
abort_controller_.Reset(isolate, abort_controller);
|
||||
language_model_.Reset(isolate, language_model);
|
||||
responder_.Bind(std::move(pending_responder));
|
||||
responder_.set_disconnect_handler(
|
||||
base::BindOnce(&PromptResponder::DeleteThis, base::Unretained(this)));
|
||||
|
||||
destroy_subscription_ = model->AddDestroyObserver(base::BindRepeating(
|
||||
&PromptResponder::OnModelDestroyed, base::Unretained(this)));
|
||||
|
||||
Respond(isolate, value);
|
||||
}
|
||||
|
||||
// disable copy
|
||||
PromptResponder(const PromptResponder&) = delete;
|
||||
PromptResponder& operator=(const PromptResponder&) = delete;
|
||||
|
||||
private:
|
||||
void OnModelDestroyed() {
|
||||
// Drop the subscription since the model is already being destroyed.
|
||||
destroy_subscription_ = {};
|
||||
responder_->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
/*quota_error_info=*/nullptr);
|
||||
DeleteThis();
|
||||
}
|
||||
|
||||
void Respond(v8::Isolate* isolate, v8::Local<v8::Value> value) {
|
||||
if (value->IsPromise()) {
|
||||
auto promise = value.As<v8::Promise>();
|
||||
|
||||
auto then_cb = base::BindOnce(
|
||||
[](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->RespondImplementation(isolate, result);
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), isolate);
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](base::WeakPtr<PromptResponder> weak_ptr,
|
||||
v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->SendError();
|
||||
weak_ptr->DeleteThis();
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr());
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
} else {
|
||||
RespondImplementation(isolate, value);
|
||||
}
|
||||
}
|
||||
|
||||
void RespondImplementation(v8::Isolate* isolate, v8::Local<v8::Value> val) {
|
||||
std::string response;
|
||||
|
||||
if (val->IsString() && gin::ConvertFromV8(isolate, val, &response)) {
|
||||
responder_->OnStreaming(response);
|
||||
uint64_t token_count =
|
||||
GetContextUsage(isolate, language_model_.Get(isolate));
|
||||
responder_->OnCompletion(
|
||||
blink::mojom::ModelExecutionContextInfo::New(token_count));
|
||||
completed_ = true;
|
||||
DeleteThis();
|
||||
} else if (IsReadableStream(isolate, val)) {
|
||||
v8::Local<v8::Value> reader =
|
||||
gin_helper::CallMethod(isolate, val.As<v8::Object>(), "getReader");
|
||||
DCHECK(reader->IsObject());
|
||||
readable_stream_reader_.Reset(isolate, reader.As<v8::Object>());
|
||||
Read(isolate);
|
||||
} else {
|
||||
SendError();
|
||||
DeleteThis();
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate, "Invalid return value from LanguageModel.prompt()"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
}
|
||||
}
|
||||
|
||||
void Read(v8::Isolate* isolate) {
|
||||
v8::Local<v8::Value> val = gin_helper::CallMethod(
|
||||
isolate, readable_stream_reader_.Get(isolate), "read");
|
||||
DCHECK(val->IsPromise());
|
||||
|
||||
auto promise = val.As<v8::Promise>();
|
||||
|
||||
auto then_cb = base::BindOnce(
|
||||
[](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
CHECK(result->IsObject());
|
||||
|
||||
v8::Local<v8::Value> done =
|
||||
result.As<v8::Object>()
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "done"))
|
||||
.ToLocalChecked();
|
||||
CHECK(done->IsBoolean());
|
||||
|
||||
if (done.As<v8::Boolean>()->Value()) {
|
||||
uint64_t token_count = GetContextUsage(
|
||||
isolate, weak_ptr->language_model_.Get(isolate));
|
||||
weak_ptr->responder_->OnCompletion(
|
||||
blink::mojom::ModelExecutionContextInfo::New(token_count));
|
||||
weak_ptr->completed_ = true;
|
||||
weak_ptr->DeleteThis();
|
||||
} else {
|
||||
v8::Local<v8::Value> val =
|
||||
result.As<v8::Object>()
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "value"))
|
||||
.ToLocalChecked();
|
||||
DCHECK(val->IsString());
|
||||
|
||||
std::string value;
|
||||
|
||||
if (gin::ConvertFromV8(isolate, val, &value)) {
|
||||
weak_ptr->responder_->OnStreaming(value);
|
||||
weak_ptr->Read(isolate);
|
||||
} else {
|
||||
weak_ptr->SendError();
|
||||
weak_ptr->DeleteThis();
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), isolate);
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](base::WeakPtr<PromptResponder> weak_ptr,
|
||||
v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->SendError();
|
||||
weak_ptr->DeleteThis();
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr());
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
}
|
||||
|
||||
void SendError() {
|
||||
responder_->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
|
||||
/*quota_error_info=*/nullptr);
|
||||
}
|
||||
|
||||
void DeleteThis() {
|
||||
destroy_subscription_ = {};
|
||||
weak_ptr_factory_.InvalidateWeakPtrs();
|
||||
|
||||
if (!completed_) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
if (!readable_stream_reader_.IsEmpty()) {
|
||||
gin_helper::CallMethod(isolate, readable_stream_reader_.Get(isolate),
|
||||
"cancel");
|
||||
}
|
||||
|
||||
gin_helper::CallMethod(isolate, abort_controller_.Get(isolate), "abort");
|
||||
}
|
||||
|
||||
delete this;
|
||||
}
|
||||
|
||||
bool completed_ = false;
|
||||
v8::Global<v8::Object> readable_stream_reader_;
|
||||
v8::Global<v8::Object> abort_controller_;
|
||||
v8::Global<v8::Object> language_model_;
|
||||
mojo::Remote<blink::mojom::ModelStreamingResponder> responder_;
|
||||
base::CallbackListSubscription destroy_subscription_;
|
||||
|
||||
base::WeakPtrFactory<PromptResponder> weak_ptr_factory_{this};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
UtilityAILanguageModel::UtilityAILanguageModel(
|
||||
v8::Local<v8::Object> language_model,
|
||||
base::WeakPtr<UtilityAIManager> manager)
|
||||
: manager_(std::move(manager)) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
language_model_.Reset(isolate, language_model);
|
||||
responder_set_.set_disconnect_handler(
|
||||
base::BindRepeating(&UtilityAILanguageModel::OnResponderDisconnect,
|
||||
weak_ptr_factory_.GetWeakPtr()));
|
||||
}
|
||||
|
||||
UtilityAILanguageModel::~UtilityAILanguageModel() {
|
||||
if (!is_destroyed_) {
|
||||
Destroy();
|
||||
}
|
||||
}
|
||||
|
||||
base::CallbackListSubscription UtilityAILanguageModel::AddDestroyObserver(
|
||||
base::RepeatingClosure callback) {
|
||||
return on_destroy_.Add(std::move(callback));
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::OnResponderDisconnect(
|
||||
mojo::RemoteSetElementId id) {
|
||||
auto it = abort_controllers_.find(id);
|
||||
if (it != abort_controllers_.end()) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort");
|
||||
abort_controllers_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
blink::mojom::ModelStreamingResponder* UtilityAILanguageModel::GetResponder(
|
||||
mojo::RemoteSetElementId responder_id) {
|
||||
return responder_set_.Get(responder_id);
|
||||
}
|
||||
|
||||
// static
|
||||
bool UtilityAILanguageModel::IsLanguageModel(v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> val) {
|
||||
static base::NoDestructor<v8::Global<v8::Function>> is_language_model;
|
||||
|
||||
auto context = isolate->GetCurrentContext();
|
||||
|
||||
if (is_language_model.get()->IsEmpty()) {
|
||||
is_language_model->Reset(
|
||||
isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelKey));
|
||||
}
|
||||
|
||||
v8::Local<v8::Value> args[] = {val};
|
||||
v8::Local<v8::Value> result =
|
||||
is_language_model->Get(isolate)
|
||||
->Call(context, v8::Null(isolate), std::size(args), args)
|
||||
.ToLocalChecked();
|
||||
|
||||
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
|
||||
}
|
||||
|
||||
// static
|
||||
bool UtilityAILanguageModel::IsLanguageModelClass(v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> val) {
|
||||
static base::NoDestructor<v8::Global<v8::Function>> is_language_model_class;
|
||||
|
||||
auto context = isolate->GetCurrentContext();
|
||||
|
||||
if (is_language_model_class.get()->IsEmpty()) {
|
||||
is_language_model_class->Reset(
|
||||
isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelClassKey));
|
||||
}
|
||||
|
||||
v8::Local<v8::Value> args[] = {val};
|
||||
v8::Local<v8::Value> result =
|
||||
is_language_model_class->Get(isolate)
|
||||
->Call(context, v8::Null(isolate), std::size(args), args)
|
||||
.ToLocalChecked();
|
||||
|
||||
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::Prompt(
|
||||
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
|
||||
on_device_model::mojom::ResponseConstraintPtr constraint,
|
||||
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
|
||||
pending_responder) {
|
||||
if (is_destroyed_) {
|
||||
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
|
||||
std::move(pending_responder));
|
||||
responder->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
/*quota_error_info=*/nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
|
||||
|
||||
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
|
||||
if (!constraint.is_null()) {
|
||||
options.Set("responseConstraint", gin::ConvertToV8(isolate, constraint));
|
||||
}
|
||||
options.Set("signal", abort_controller
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "signal"))
|
||||
.ToLocalChecked());
|
||||
|
||||
v8::Local<v8::Value> val = gin_helper::CallMethod(
|
||||
isolate, language_model_.Get(isolate), "prompt", prompts, options);
|
||||
|
||||
new PromptResponder(isolate, val, abort_controller,
|
||||
language_model_.Get(isolate),
|
||||
std::move(pending_responder), this);
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::Append(
|
||||
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
|
||||
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
|
||||
pending_responder) {
|
||||
if (is_destroyed_) {
|
||||
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
|
||||
std::move(pending_responder));
|
||||
responder->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
/*quota_error_info=*/nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
mojo::RemoteSetElementId responder_id =
|
||||
responder_set_.Add(std::move(pending_responder));
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
|
||||
abort_controllers_.emplace(responder_id,
|
||||
v8::Global<v8::Object>(isolate, abort_controller));
|
||||
|
||||
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
|
||||
options.Set("signal", abort_controller
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "signal"))
|
||||
.ToLocalChecked());
|
||||
|
||||
v8::Local<v8::Value> val = gin_helper::CallMethod(
|
||||
isolate, language_model_.Get(isolate), "append", prompts, options);
|
||||
|
||||
auto SendResponse =
|
||||
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr, v8::Isolate* isolate,
|
||||
mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
|
||||
if (!weak_ptr)
|
||||
return;
|
||||
weak_ptr->abort_controllers_.erase(responder_id);
|
||||
|
||||
blink::mojom::ModelStreamingResponder* responder =
|
||||
weak_ptr->GetResponder(responder_id);
|
||||
if (!responder) {
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t token_count =
|
||||
GetContextUsage(isolate, weak_ptr->language_model_.Get(isolate));
|
||||
responder->OnCompletion(
|
||||
blink::mojom::ModelExecutionContextInfo::New(token_count));
|
||||
};
|
||||
|
||||
if (val->IsPromise()) {
|
||||
auto promise = val.As<v8::Promise>();
|
||||
|
||||
auto then_cb = base::BindOnce(SendResponse, weak_ptr_factory_.GetWeakPtr(),
|
||||
isolate, responder_id);
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
|
||||
mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
|
||||
if (!weak_ptr)
|
||||
return;
|
||||
weak_ptr->abort_controllers_.erase(responder_id);
|
||||
|
||||
blink::mojom::ModelStreamingResponder* responder =
|
||||
weak_ptr->GetResponder(responder_id);
|
||||
if (!responder) {
|
||||
return;
|
||||
}
|
||||
|
||||
responder->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
|
||||
/*quota_error_info=*/nullptr);
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), responder_id);
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
} else {
|
||||
// The method is supposed to return a promise, but for
|
||||
// convenience allow developers to return a value directly
|
||||
SendResponse(weak_ptr_factory_.GetWeakPtr(), isolate, responder_id, val);
|
||||
}
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::Fork(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client) {
|
||||
if (is_destroyed_ || !manager_) {
|
||||
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client_remote(std::move(client));
|
||||
client_remote->OnError(
|
||||
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
|
||||
/*quota_error_info=*/nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
manager_->CreateLanguageModelInternal(
|
||||
isolate, std::move(client), language_model_.Get(isolate), "clone",
|
||||
gin_helper::Dictionary::CreateEmpty(isolate),
|
||||
blink::mojom::AILanguageModelCreateOptions::New());
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::Destroy() {
|
||||
if (is_destroyed_) {
|
||||
return;
|
||||
}
|
||||
|
||||
is_destroyed_ = true;
|
||||
|
||||
// Notify observers (e.g. in-progress PromptResponders) before
|
||||
// tearing down the responder set and abort controllers.
|
||||
on_destroy_.Notify();
|
||||
|
||||
for (auto& responder : responder_set_) {
|
||||
responder->OnError(
|
||||
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
|
||||
/*quota_error_info=*/nullptr);
|
||||
}
|
||||
responder_set_.Clear();
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
for (auto& [id, controller] : abort_controllers_) {
|
||||
gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
|
||||
}
|
||||
abort_controllers_.clear();
|
||||
|
||||
for (auto& controller : measure_abort_controllers_) {
|
||||
gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
|
||||
}
|
||||
measure_abort_controllers_.clear();
|
||||
|
||||
gin_helper::CallMethod(isolate, language_model_.Get(isolate), "destroy");
|
||||
}
|
||||
|
||||
void UtilityAILanguageModel::MeasureInputUsage(
|
||||
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
|
||||
MeasureInputUsageCallback callback) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
|
||||
measure_abort_controllers_.emplace_back(isolate, abort_controller);
|
||||
auto abort_it = std::prev(measure_abort_controllers_.end());
|
||||
|
||||
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
|
||||
options.Set("signal", abort_controller
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "signal"))
|
||||
.ToLocalChecked());
|
||||
|
||||
v8::Local<v8::Value> val =
|
||||
gin_helper::CallMethod(isolate, language_model_.Get(isolate),
|
||||
"measureContextUsage", input, options);
|
||||
|
||||
auto RunCallback = [](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
|
||||
std::list<v8::Global<v8::Object>>::iterator abort_it,
|
||||
v8::Isolate* isolate,
|
||||
MeasureInputUsageCallback callback,
|
||||
v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->measure_abort_controllers_.erase(abort_it);
|
||||
}
|
||||
|
||||
uint32_t input_tokens = 0;
|
||||
|
||||
if (result->IsNumber() &&
|
||||
gin::ConvertFromV8(isolate, result, &input_tokens)) {
|
||||
std::move(callback).Run(std::move(input_tokens));
|
||||
} else if (result->IsNull()) {
|
||||
std::move(callback).Run(std::nullopt);
|
||||
} else {
|
||||
std::move(callback).Run(std::nullopt);
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate,
|
||||
"Invalid return value from LanguageModel.measureContextUsage()"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
}
|
||||
};
|
||||
|
||||
if (val->IsPromise()) {
|
||||
auto promise = val.As<v8::Promise>();
|
||||
auto split_callback = base::SplitOnceCallback(std::move(callback));
|
||||
|
||||
auto then_cb =
|
||||
base::BindOnce(RunCallback, weak_ptr_factory_.GetWeakPtr(), abort_it,
|
||||
isolate, std::move(split_callback.first));
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
|
||||
std::list<v8::Global<v8::Object>>::iterator abort_it,
|
||||
MeasureInputUsageCallback callback, v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->measure_abort_controllers_.erase(abort_it);
|
||||
}
|
||||
std::move(callback).Run(std::nullopt);
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), abort_it,
|
||||
std::move(split_callback.second));
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
} else {
|
||||
// The method is supposed to return a promise, but for
|
||||
// convenience allow developers to return a value directly
|
||||
RunCallback(weak_ptr_factory_.GetWeakPtr(), abort_it, isolate,
|
||||
std::move(callback), val);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace electron
|
||||
99
shell/utility/ai/utility_ai_language_model.h
Normal file
99
shell/utility/ai/utility_ai_language_model.h
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
|
||||
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
|
||||
|
||||
#include <list>
|
||||
#include <vector>
|
||||
|
||||
#include "base/callback_list.h"
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "gin/converter.h"
|
||||
#include "mojo/public/cpp/bindings/pending_remote.h"
|
||||
#include "mojo/public/cpp/bindings/receiver.h"
|
||||
#include "mojo/public/cpp/bindings/remote_set.h"
|
||||
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
|
||||
#include "v8/include/v8.h"
|
||||
|
||||
namespace electron {
|
||||
|
||||
class UtilityAIManager;
|
||||
|
||||
class UtilityAILanguageModel : public blink::mojom::AILanguageModel {
|
||||
public:
|
||||
UtilityAILanguageModel(v8::Local<v8::Object> language_model,
|
||||
base::WeakPtr<UtilityAIManager> manager);
|
||||
UtilityAILanguageModel(const UtilityAILanguageModel&) = delete;
|
||||
UtilityAILanguageModel& operator=(const UtilityAILanguageModel&) = delete;
|
||||
|
||||
~UtilityAILanguageModel() override;
|
||||
|
||||
// Subscribe to be notified when this model is destroyed. The returned
|
||||
// subscription auto-unregisters when destroyed.
|
||||
[[nodiscard]] base::CallbackListSubscription AddDestroyObserver(
|
||||
base::RepeatingClosure callback);
|
||||
|
||||
static bool IsLanguageModel(v8::Isolate* isolate, v8::Local<v8::Value> val);
|
||||
static bool IsLanguageModelClass(v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> val);
|
||||
|
||||
// `blink::mojom::AILanguageModel` implementation.
|
||||
void Prompt(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
|
||||
on_device_model::mojom::ResponseConstraintPtr constraint,
|
||||
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
|
||||
pending_responder) override;
|
||||
void Append(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
|
||||
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
|
||||
pending_responder) override;
|
||||
void Fork(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client) override;
|
||||
void Destroy() override;
|
||||
void MeasureInputUsage(
|
||||
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
|
||||
MeasureInputUsageCallback callback) override;
|
||||
|
||||
private:
|
||||
void OnResponderDisconnect(mojo::RemoteSetElementId id);
|
||||
|
||||
blink::mojom::ModelStreamingResponder* GetResponder(
|
||||
mojo::RemoteSetElementId responder_id);
|
||||
|
||||
base::WeakPtr<UtilityAIManager> manager_;
|
||||
v8::Global<v8::Object> language_model_;
|
||||
bool is_destroyed_ = false;
|
||||
|
||||
mojo::RemoteSet<blink::mojom::ModelStreamingResponder> responder_set_;
|
||||
|
||||
// Maps each in-progress Prompt/Append responder to its AbortController
|
||||
// so we can abort the JS-side operation if the responder disconnects.
|
||||
absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
|
||||
abort_controllers_;
|
||||
|
||||
// Tracks abort controllers for in-progress MeasureInputUsage calls.
|
||||
std::list<v8::Global<v8::Object>> measure_abort_controllers_;
|
||||
|
||||
// Notified when this model is destroyed, allowing in-progress
|
||||
// PromptResponder instances to clean up.
|
||||
base::RepeatingClosureList on_destroy_;
|
||||
|
||||
base::WeakPtrFactory<UtilityAILanguageModel> weak_ptr_factory_{this};
|
||||
};
|
||||
|
||||
} // namespace electron
|
||||
|
||||
namespace gin {
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelPromptPtr> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageModelPromptPtr& val);
|
||||
};
|
||||
|
||||
} // namespace gin
|
||||
|
||||
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
|
||||
523
shell/utility/ai/utility_ai_manager.cc
Normal file
523
shell/utility/ai/utility_ai_manager.cc
Normal file
@@ -0,0 +1,523 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "shell/utility/ai/utility_ai_manager.h"
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
#include "base/containers/fixed_flat_map.h"
|
||||
#include "base/notimplemented.h"
|
||||
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
|
||||
#include "shell/browser/javascript_environment.h"
|
||||
#include "shell/common/gin_converters/callback_converter.h"
|
||||
#include "shell/common/gin_converters/std_converter.h"
|
||||
#include "shell/common/gin_helper/dictionary.h"
|
||||
#include "shell/common/gin_helper/event_emitter_caller.h"
|
||||
#include "shell/common/node_includes.h"
|
||||
#include "shell/common/node_util.h"
|
||||
#include "shell/utility/ai/utility_ai_language_model.h"
|
||||
#include "shell/utility/api/electron_api_local_ai_handler.h"
|
||||
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
|
||||
#include "url/gurl.h"
|
||||
#include "url/origin.h"
|
||||
#include "v8/include/v8.h"
|
||||
|
||||
namespace gin {
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::ModelAvailabilityCheckResult> {
|
||||
static bool FromV8(v8::Isolate* isolate,
|
||||
v8::Local<v8::Value> val,
|
||||
blink::mojom::ModelAvailabilityCheckResult* out) {
|
||||
using Result = blink::mojom::ModelAvailabilityCheckResult;
|
||||
static constexpr auto Lookup =
|
||||
base::MakeFixedFlatMap<std::string_view, Result>({
|
||||
{"available", Result::kAvailable},
|
||||
{"unavailable", Result::kUnavailableUnknown},
|
||||
{"downloading", Result::kDownloading},
|
||||
{"downloadable", Result::kDownloadable},
|
||||
});
|
||||
return FromV8WithLookup(isolate, val, Lookup, out);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelPromptType> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
blink::mojom::AILanguageModelPromptType value) {
|
||||
switch (value) {
|
||||
case blink::mojom::AILanguageModelPromptType::kText:
|
||||
return StringToV8(isolate, "text");
|
||||
case blink::mojom::AILanguageModelPromptType::kImage:
|
||||
return StringToV8(isolate, "image");
|
||||
case blink::mojom::AILanguageModelPromptType::kAudio:
|
||||
return StringToV8(isolate, "audio");
|
||||
default:
|
||||
return StringToV8(isolate, "unknown");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageCodePtr> {
|
||||
static v8::Local<v8::Value> ToV8(v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageCodePtr& val) {
|
||||
if (val.is_null()) {
|
||||
return v8::Undefined(isolate);
|
||||
}
|
||||
|
||||
return StringToV8(isolate, val->code);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelExpectedPtr> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageModelExpectedPtr& val) {
|
||||
if (val.is_null()) {
|
||||
return v8::Undefined(isolate);
|
||||
}
|
||||
|
||||
auto dict = gin::Dictionary::CreateEmpty(isolate);
|
||||
|
||||
dict.Set("type", val->type);
|
||||
|
||||
if (val->languages.has_value() && !val->languages->empty()) {
|
||||
dict.Set("languages", val->languages.value());
|
||||
}
|
||||
|
||||
return ConvertToV8(isolate, dict);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Converter<blink::mojom::AILanguageModelCreateOptionsPtr> {
|
||||
static v8::Local<v8::Value> ToV8(
|
||||
v8::Isolate* isolate,
|
||||
const blink::mojom::AILanguageModelCreateOptionsPtr& val) {
|
||||
if (val.is_null() ||
|
||||
(val->sampling_params.is_null() && !val->expected_inputs.has_value() &&
|
||||
!val->expected_outputs.has_value() && val->initial_prompts.empty())) {
|
||||
return v8::Undefined(isolate);
|
||||
}
|
||||
|
||||
auto dict = gin::Dictionary::CreateEmpty(isolate);
|
||||
|
||||
if (val->expected_inputs.has_value() && !val->expected_inputs->empty()) {
|
||||
dict.Set("expectedInputs", val->expected_inputs.value());
|
||||
}
|
||||
|
||||
if (val->expected_outputs.has_value() && !val->expected_outputs->empty()) {
|
||||
dict.Set("expectedOutputs", val->expected_outputs.value());
|
||||
}
|
||||
|
||||
if (!val->initial_prompts.empty()) {
|
||||
dict.Set("initialPrompts", val->initial_prompts);
|
||||
}
|
||||
|
||||
return ConvertToV8(isolate, dict);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace gin
|
||||
|
||||
namespace electron {
|
||||
|
||||
UtilityAIManager::UtilityAIManager(std::optional<int32_t> web_contents_id,
|
||||
const url::Origin& security_origin)
|
||||
: web_contents_id_(web_contents_id), security_origin_(security_origin) {
|
||||
create_model_client_set_.set_disconnect_with_reason_handler(
|
||||
base::BindRepeating(
|
||||
&UtilityAIManager::OnCreateLanguageModelClientDisconnect,
|
||||
weak_ptr_factory_.GetWeakPtr()));
|
||||
}
|
||||
|
||||
UtilityAIManager::~UtilityAIManager() {
|
||||
// Trigger the abort signal for any in-progress CreateLanguageModel calls
|
||||
for (auto it = create_model_client_set_.begin();
|
||||
it != create_model_client_set_.end(); ++it) {
|
||||
OnCreateLanguageModelClientDisconnect(it.id(), 0, std::string());
|
||||
}
|
||||
}
|
||||
|
||||
void UtilityAIManager::OnCreateLanguageModelClientDisconnect(
|
||||
mojo::RemoteSetElementId id,
|
||||
uint32_t custom_reason,
|
||||
const std::string& description) {
|
||||
auto it = abort_controllers_.find(id);
|
||||
if (it != abort_controllers_.end()) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
if (description.empty()) {
|
||||
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort");
|
||||
} else {
|
||||
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort",
|
||||
description);
|
||||
}
|
||||
abort_controllers_.erase(it);
|
||||
}
|
||||
}
|
||||
|
||||
v8::Global<v8::Object>& UtilityAIManager::GetLanguageModelClass() {
|
||||
if (language_model_class_.IsEmpty()) {
|
||||
auto& handler = electron::api::local_ai_handler::GetPromptAPIHandler();
|
||||
|
||||
if (handler.has_value()) {
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
auto details = gin_helper::Dictionary::CreateEmpty(isolate);
|
||||
if (web_contents_id_.has_value()) {
|
||||
details.Set("webContentsId", web_contents_id_.value());
|
||||
} else {
|
||||
details.Set("webContentsId", nullptr);
|
||||
}
|
||||
details.Set("securityOrigin", security_origin_.Serialize());
|
||||
|
||||
v8::Local<v8::Value> val = handler->Run(details);
|
||||
|
||||
if (val->IsPromise()) {
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate, "Cannot return a promise from the handler"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
return language_model_class_;
|
||||
}
|
||||
|
||||
if (!val->IsObject() ||
|
||||
!val->ToObject(isolate->GetCurrentContext())
|
||||
.ToLocalChecked()
|
||||
->IsConstructor() ||
|
||||
!UtilityAILanguageModel::IsLanguageModelClass(isolate, val)) {
|
||||
auto err = v8::Exception::TypeError(
|
||||
gin::StringToV8(isolate, "Must provide a constructible class"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
return language_model_class_;
|
||||
} else {
|
||||
language_model_class_.Reset(
|
||||
isolate,
|
||||
val->ToObject(isolate->GetCurrentContext()).ToLocalChecked());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return language_model_class_;
|
||||
}
|
||||
|
||||
void UtilityAIManager::SendCreateLanguageModelError(
|
||||
mojo::RemoteSetElementId client_id,
|
||||
blink::mojom::AIManagerCreateClientError error) {
|
||||
abort_controllers_.erase(client_id);
|
||||
|
||||
blink::mojom::AIManagerCreateLanguageModelClient* client =
|
||||
create_model_client_set_.Get(client_id);
|
||||
if (!client) {
|
||||
return;
|
||||
}
|
||||
|
||||
client->OnError(error, /*quota_error_info=*/nullptr);
|
||||
}
|
||||
|
||||
void UtilityAIManager::HandleLanguageModelResult(
|
||||
v8::Isolate* isolate,
|
||||
v8::Local<v8::Object> language_model,
|
||||
mojo::RemoteSetElementId client_id,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) {
|
||||
abort_controllers_.erase(client_id);
|
||||
|
||||
gin_helper::Dictionary dict;
|
||||
uint64_t context_usage = 0;
|
||||
uint64_t context_quota = 0;
|
||||
|
||||
if (!ConvertFromV8(isolate, language_model, &dict) ||
|
||||
!dict.Get("contextUsage", &context_usage) ||
|
||||
!dict.Get("contextWindow", &context_quota)) {
|
||||
SendCreateLanguageModelError(
|
||||
client_id,
|
||||
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO - How can the implementation specify the supported prompt types? For
|
||||
// now, assume all types are supported if the handler returns a valid object
|
||||
base::flat_set<blink::mojom::AILanguageModelPromptType> enabled_input_types;
|
||||
if (options->expected_inputs.has_value()) {
|
||||
for (const auto& expected_input : options->expected_inputs.value()) {
|
||||
enabled_input_types.insert(expected_input->type);
|
||||
}
|
||||
}
|
||||
|
||||
blink::mojom::AIManagerCreateLanguageModelClient* client =
|
||||
create_model_client_set_.Get(client_id);
|
||||
if (!client) {
|
||||
return;
|
||||
}
|
||||
|
||||
mojo::PendingRemote<blink::mojom::AILanguageModel> language_model_remote;
|
||||
|
||||
language_model_receivers_.Add(
|
||||
std::make_unique<UtilityAILanguageModel>(language_model,
|
||||
weak_ptr_factory_.GetWeakPtr()),
|
||||
language_model_remote.InitWithNewPipeAndPassReceiver());
|
||||
|
||||
client->OnResult(
|
||||
std::move(language_model_remote),
|
||||
blink::mojom::AILanguageModelInstanceInfo::New(
|
||||
context_quota, context_usage,
|
||||
blink::mojom::AILanguageModelSamplingParams::New(),
|
||||
std::vector<blink::mojom::AILanguageModelPromptType>(
|
||||
enabled_input_types.begin(), enabled_input_types.end()),
|
||||
/*audio_sample_rate_hz=*/std::nullopt,
|
||||
/*audio_channel_count=*/std::nullopt));
|
||||
}
|
||||
|
||||
void UtilityAIManager::CanCreateLanguageModel(
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options,
|
||||
CanCreateLanguageModelCallback callback) {
|
||||
v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
|
||||
blink::mojom::ModelAvailabilityCheckResult availability =
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
|
||||
|
||||
if (language_model_class.IsEmpty()) {
|
||||
std::move(callback).Run(availability);
|
||||
} else {
|
||||
// If a handler is set, we can create a language model.
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
v8::Local<v8::Value> val = gin_helper::CallMethod(
|
||||
isolate, language_model_class.Get(isolate), "availability", options);
|
||||
|
||||
auto RunCallback = [](v8::Isolate* isolate,
|
||||
CanCreateLanguageModelCallback callback,
|
||||
v8::Local<v8::Value> result) {
|
||||
blink::mojom::ModelAvailabilityCheckResult availability =
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
|
||||
|
||||
if (result->IsString() &&
|
||||
gin::ConvertFromV8(isolate, result, &availability)) {
|
||||
std::move(callback).Run(availability);
|
||||
} else {
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate, "Invalid return value from LanguageModel.availability()"));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
};
|
||||
|
||||
if (val->IsPromise()) {
|
||||
auto promise = val.As<v8::Promise>();
|
||||
auto split_callback = base::SplitOnceCallback(std::move(callback));
|
||||
|
||||
auto then_cb =
|
||||
base::BindOnce(RunCallback, isolate, std::move(split_callback.first));
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](CanCreateLanguageModelCallback callback,
|
||||
v8::Local<v8::Value> result) {
|
||||
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
|
||||
kUnavailableUnknown);
|
||||
},
|
||||
std::move(split_callback.second));
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
} else {
|
||||
// The method is supposed to return a promise, but for
|
||||
// convenience allow developers to return a value directly
|
||||
RunCallback(isolate, std::move(callback), val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateLanguageModel(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) {
|
||||
v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
|
||||
|
||||
// Can't create language model if there's no language model class
|
||||
if (language_model_class.IsEmpty()) {
|
||||
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client_remote(std::move(client));
|
||||
client_remote->OnError(
|
||||
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
|
||||
/*quota_error_info=*/nullptr);
|
||||
return;
|
||||
}
|
||||
|
||||
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
|
||||
v8::HandleScope scope{isolate};
|
||||
|
||||
gin_helper::Dictionary options_dict{
|
||||
isolate, gin::ConvertToV8(isolate, options).As<v8::Object>()};
|
||||
|
||||
CreateLanguageModelInternal(isolate, std::move(client),
|
||||
language_model_class.Get(isolate), "create",
|
||||
std::move(options_dict), std::move(options));
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateLanguageModelInternal(
|
||||
v8::Isolate* isolate,
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
v8::Local<v8::Object> target,
|
||||
std::string_view method_name,
|
||||
gin_helper::Dictionary options_dict,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) {
|
||||
DCHECK(method_name == "create" || method_name == "clone");
|
||||
|
||||
std::string error_source = "LanguageModel." + std::string(method_name) + "()";
|
||||
|
||||
mojo::RemoteSetElementId client_id =
|
||||
create_model_client_set_.Add(std::move(client));
|
||||
|
||||
// Store the abort controller so the disconnect handler can abort it.
|
||||
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
|
||||
|
||||
abort_controllers_.emplace(client_id,
|
||||
v8::Global<v8::Object>(isolate, abort_controller));
|
||||
|
||||
options_dict.Set("signal", abort_controller
|
||||
->Get(isolate->GetCurrentContext(),
|
||||
gin::StringToV8(isolate, "signal"))
|
||||
.ToLocalChecked());
|
||||
|
||||
v8::Local<v8::Value> val =
|
||||
gin_helper::CallMethod(isolate, target, method_name.data(), options_dict);
|
||||
|
||||
if (val->IsPromise()) {
|
||||
auto promise = val.As<v8::Promise>();
|
||||
|
||||
auto then_cb = base::BindOnce(
|
||||
[](base::WeakPtr<UtilityAIManager> weak_ptr, v8::Isolate* isolate,
|
||||
mojo::RemoteSetElementId client_id,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options,
|
||||
std::string error_source, v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
if (result->IsObject() &&
|
||||
UtilityAILanguageModel::IsLanguageModel(isolate, result)) {
|
||||
weak_ptr->HandleLanguageModelResult(
|
||||
isolate, result.As<v8::Object>(), client_id,
|
||||
std::move(options));
|
||||
} else {
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate, "Invalid return value from " + error_source));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
weak_ptr->SendCreateLanguageModelError(
|
||||
client_id, blink::mojom::AIManagerCreateClientError::
|
||||
kUnableToCreateSession);
|
||||
}
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), isolate, client_id, std::move(options),
|
||||
std::string(error_source));
|
||||
|
||||
auto catch_cb = base::BindOnce(
|
||||
[](base::WeakPtr<UtilityAIManager> weak_ptr,
|
||||
mojo::RemoteSetElementId client_id, v8::Local<v8::Value> result) {
|
||||
if (weak_ptr) {
|
||||
weak_ptr->SendCreateLanguageModelError(
|
||||
client_id, blink::mojom::AIManagerCreateClientError::
|
||||
kUnableToCreateSession);
|
||||
}
|
||||
},
|
||||
weak_ptr_factory_.GetWeakPtr(), client_id);
|
||||
|
||||
std::ignore = promise->Then(
|
||||
isolate->GetCurrentContext(),
|
||||
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
|
||||
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
|
||||
} else if (val->IsObject() &&
|
||||
UtilityAILanguageModel::IsLanguageModel(isolate, val)) {
|
||||
// The method is supposed to return a promise, but for
|
||||
// convenience allow developers to return a value directly
|
||||
HandleLanguageModelResult(isolate, val.As<v8::Object>(), client_id,
|
||||
std::move(options));
|
||||
} else {
|
||||
auto err = v8::Exception::TypeError(gin::StringToV8(
|
||||
isolate, "Invalid return value from " + std::string(error_source)));
|
||||
node::errors::TriggerUncaughtException(isolate, err, {});
|
||||
SendCreateLanguageModelError(
|
||||
client_id,
|
||||
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
|
||||
}
|
||||
}
|
||||
|
||||
void UtilityAIManager::CanCreateSummarizer(
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options,
|
||||
CanCreateSummarizerCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateSummarizer(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void UtilityAIManager::GetLanguageModelParams(
|
||||
GetLanguageModelParamsCallback callback) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void UtilityAIManager::CanCreateWriter(
|
||||
blink::mojom::AIWriterCreateOptionsPtr options,
|
||||
CanCreateWriterCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateWriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
|
||||
blink::mojom::AIWriterCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void UtilityAIManager::CanCreateRewriter(
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options,
|
||||
CanCreateRewriterCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateRewriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void UtilityAIManager::CanCreateProofreader(
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options,
|
||||
CanCreateProofreaderCallback callback) {
|
||||
std::move(callback).Run(
|
||||
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
|
||||
}
|
||||
|
||||
void UtilityAIManager::CreateProofreader(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
void UtilityAIManager::AddModelDownloadProgressObserver(
|
||||
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
|
||||
observer_remote) {
|
||||
NOTIMPLEMENTED();
|
||||
}
|
||||
|
||||
} // namespace electron
|
||||
125
shell/utility/ai/utility_ai_manager.h
Normal file
125
shell/utility/ai/utility_ai_manager.h
Normal file
@@ -0,0 +1,125 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
|
||||
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
|
||||
#include "base/memory/weak_ptr.h"
|
||||
#include "mojo/public/cpp/bindings/pending_remote.h"
|
||||
#include "mojo/public/cpp/bindings/remote_set.h"
|
||||
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
|
||||
#include "services/on_device_model/public/mojom/download_observer.mojom-forward.h"
|
||||
#include "shell/common/gin_helper/dictionary.h"
|
||||
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
|
||||
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
|
||||
#include "url/origin.h"
|
||||
#include "v8/include/v8.h"
|
||||
|
||||
namespace electron {
|
||||
|
||||
class UtilityAILanguageModel;
|
||||
|
||||
// The utility-side implementation of `blink::mojom::AIManager`.
|
||||
class UtilityAIManager : public blink::mojom::AIManager {
|
||||
public:
|
||||
UtilityAIManager(std::optional<int32_t> web_contents_id,
|
||||
const url::Origin& security_origin);
|
||||
UtilityAIManager(const UtilityAIManager&) = delete;
|
||||
UtilityAIManager& operator=(const UtilityAIManager&) = delete;
|
||||
|
||||
~UtilityAIManager() override;
|
||||
|
||||
private:
|
||||
friend class UtilityAILanguageModel;
|
||||
|
||||
void OnCreateLanguageModelClientDisconnect(mojo::RemoteSetElementId id,
|
||||
uint32_t custom_reason,
|
||||
const std::string& description);
|
||||
[[nodiscard]] v8::Global<v8::Object>& GetLanguageModelClass();
|
||||
|
||||
void SendCreateLanguageModelError(
|
||||
mojo::RemoteSetElementId client_id,
|
||||
blink::mojom::AIManagerCreateClientError error);
|
||||
|
||||
void CreateLanguageModelInternal(
|
||||
v8::Isolate* isolate,
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
v8::Local<v8::Object> target,
|
||||
std::string_view method_name,
|
||||
gin_helper::Dictionary options_dict,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options);
|
||||
|
||||
void HandleLanguageModelResult(
|
||||
v8::Isolate* isolate,
|
||||
v8::Local<v8::Object> language_model,
|
||||
mojo::RemoteSetElementId client_id,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options);
|
||||
|
||||
// `blink::mojom::AIManager` implementation.
|
||||
void CanCreateLanguageModel(
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options,
|
||||
CanCreateLanguageModelCallback callback) override;
|
||||
void CreateLanguageModel(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
client,
|
||||
blink::mojom::AILanguageModelCreateOptionsPtr options) override;
|
||||
void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
|
||||
CanCreateSummarizerCallback callback) override;
|
||||
void CreateSummarizer(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
|
||||
blink::mojom::AISummarizerCreateOptionsPtr options) override;
|
||||
void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
|
||||
void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
|
||||
CanCreateWriterCallback callback) override;
|
||||
void CreateWriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
|
||||
blink::mojom::AIWriterCreateOptionsPtr options) override;
|
||||
void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
|
||||
CanCreateRewriterCallback callback) override;
|
||||
void CreateRewriter(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
|
||||
blink::mojom::AIRewriterCreateOptionsPtr options) override;
|
||||
void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
|
||||
CanCreateProofreaderCallback callback) override;
|
||||
void CreateProofreader(
|
||||
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
|
||||
client,
|
||||
blink::mojom::AIProofreaderCreateOptionsPtr options) override;
|
||||
void AddModelDownloadProgressObserver(
|
||||
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
|
||||
observer_remote) override;
|
||||
|
||||
std::optional<int32_t> web_contents_id_;
|
||||
url::Origin security_origin_;
|
||||
|
||||
v8::Global<v8::Object> language_model_class_;
|
||||
|
||||
mojo::RemoteSet<blink::mojom::AIManagerCreateLanguageModelClient>
|
||||
create_model_client_set_;
|
||||
|
||||
// Maps each in-progress CreateLanguageModel client to its AbortController
|
||||
// so we can abort the JS-side operation if the client disconnects.
|
||||
absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
|
||||
abort_controllers_;
|
||||
|
||||
// Owns all created UtilityAILanguageModel instances
|
||||
mojo::UniqueReceiverSet<blink::mojom::AILanguageModel>
|
||||
language_model_receivers_;
|
||||
|
||||
base::WeakPtrFactory<UtilityAIManager> weak_ptr_factory_{this};
|
||||
};
|
||||
|
||||
} // namespace electron
|
||||
|
||||
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
|
||||
53
shell/utility/api/electron_api_local_ai_handler.cc
Normal file
53
shell/utility/api/electron_api_local_ai_handler.cc
Normal file
@@ -0,0 +1,53 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#include "shell/utility/api/electron_api_local_ai_handler.h"
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "base/no_destructor.h"
|
||||
#include "shell/common/gin_converters/callback_converter.h"
|
||||
#include "shell/common/gin_helper/dictionary.h"
|
||||
#include "shell/common/node_includes.h"
|
||||
#include "v8/include/v8.h"
|
||||
|
||||
namespace electron::api::local_ai_handler {
|
||||
|
||||
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> val) {
|
||||
PromptAPIHandler handler;
|
||||
if (!(val->IsNull() || gin::ConvertFromV8(isolate, val, &handler))) {
|
||||
isolate->ThrowException(v8::Exception::TypeError(
|
||||
gin::StringToV8(isolate, "Must pass null or function")));
|
||||
return;
|
||||
}
|
||||
|
||||
if (val->IsNull()) {
|
||||
GetPromptAPIHandler() = std::nullopt;
|
||||
} else {
|
||||
GetPromptAPIHandler() = handler;
|
||||
}
|
||||
}
|
||||
|
||||
std::optional<PromptAPIHandler>& GetPromptAPIHandler() {
|
||||
static base::NoDestructor<std::optional<PromptAPIHandler>> prompt_api_handler;
|
||||
return *prompt_api_handler;
|
||||
}
|
||||
|
||||
} // namespace electron::api::local_ai_handler
|
||||
|
||||
namespace {
|
||||
|
||||
void Initialize(v8::Local<v8::Object> exports,
|
||||
v8::Local<v8::Value> unused,
|
||||
v8::Local<v8::Context> context,
|
||||
void* priv) {
|
||||
v8::Isolate* const isolate = v8::Isolate::GetCurrent();
|
||||
gin_helper::Dictionary dict{isolate, exports};
|
||||
dict.SetMethod("setPromptAPIHandler",
|
||||
&electron::api::local_ai_handler::SetPromptAPIHandler);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
NODE_LINKED_BINDING_CONTEXT_AWARE(electron_utility_local_ai_handler, Initialize)
|
||||
28
shell/utility/api/electron_api_local_ai_handler.h
Normal file
28
shell/utility/api/electron_api_local_ai_handler.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2025 Microsoft, Inc.
|
||||
// Use of this source code is governed by the MIT license that can be
|
||||
// found in the LICENSE file.
|
||||
|
||||
#ifndef ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_
|
||||
#define ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_
|
||||
|
||||
#include <optional>
|
||||
|
||||
#include "base/functional/callback_forward.h"
|
||||
#include "v8/include/v8-forward.h"
|
||||
|
||||
namespace gin_helper {
|
||||
class Dictionary;
|
||||
}
|
||||
|
||||
namespace electron::api::local_ai_handler {
|
||||
|
||||
using PromptAPIHandler =
|
||||
base::RepeatingCallback<v8::Local<v8::Value>(gin_helper::Dictionary)>;
|
||||
|
||||
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> value);
|
||||
|
||||
[[nodiscard]] std::optional<PromptAPIHandler>& GetPromptAPIHandler();
|
||||
|
||||
} // namespace electron::api::local_ai_handler
|
||||
|
||||
#endif // ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_
|
||||
995
spec/api-local-ai-handler-spec.ts
Normal file
995
spec/api-local-ai-handler-spec.ts
Normal file
@@ -0,0 +1,995 @@
|
||||
import { BrowserWindow, session, utilityProcess } from 'electron/main';
|
||||
|
||||
import { expect } from 'chai';
|
||||
|
||||
import { on, once } from 'node:events';
|
||||
import * as path from 'node:path';
|
||||
|
||||
import { ifdescribe } from './lib/spec-helpers';
|
||||
import { closeAllWindows } from './lib/window-helpers';
|
||||
|
||||
const features = process._linkedBinding('electron_common_features');
|
||||
|
||||
function getFixturePath (fixtureName: string) {
|
||||
return path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), fixtureName);
|
||||
}
|
||||
|
||||
// Await fn and listen for a message of the given type, returning the message once received
|
||||
// Used to listen for a message triggered as a side effect of fn, where we don't care about the result of fn
|
||||
async function listenForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<void> | void) {
|
||||
const messages = on(aiHandler, 'message');
|
||||
await fn();
|
||||
|
||||
for await (const [message] of messages) {
|
||||
if (message.type === messageType) {
|
||||
return message;
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// Call fn and await a message of the given type, returning the message and the promise returned by fn
|
||||
// Used to listen for a message triggered as a side effect of fn, where we do care about the result of fn
|
||||
async function waitForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<unknown>) {
|
||||
let promise: Promise<unknown>;
|
||||
|
||||
await listenForMessage(aiHandler, messageType, () => {
|
||||
promise = fn();
|
||||
});
|
||||
|
||||
return { promise: promise! };
|
||||
}
|
||||
|
||||
ifdescribe(features.isPromptAPIEnabled())('localAIHandler module', () => {
|
||||
const fixtures = path.resolve(__dirname, 'fixtures');
|
||||
|
||||
let w: Electron.BrowserWindow;
|
||||
|
||||
async function forkAndRegisterHandler (fixtureName: string) {
|
||||
const aiHandler = utilityProcess.fork(getFixturePath(fixtureName));
|
||||
await once(aiHandler, 'spawn');
|
||||
w.webContents.session.registerLocalAIHandler(aiHandler);
|
||||
|
||||
return aiHandler;
|
||||
}
|
||||
|
||||
async function sendControllableMessage (aiHandler: Electron.UtilityProcess, message: unknown) {
|
||||
const ackEvent = once(aiHandler, 'message');
|
||||
aiHandler.postMessage(message);
|
||||
await ackEvent;
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
w = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
enableBlinkFeatures: 'AIPromptAPI,AIPromptAPIMultimodalInput'
|
||||
}
|
||||
});
|
||||
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
closeAllWindows();
|
||||
});
|
||||
|
||||
describe('LanguageModel.availability()', () => {
|
||||
it('is unavailable if invalid value returned', async () => {
|
||||
await forkAndRegisterHandler('buggy-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('returns "available" when handler reports available', async () => {
|
||||
await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
});
|
||||
|
||||
it('returns "downloadable" when handler reports downloadable', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloadable' });
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloadable');
|
||||
});
|
||||
|
||||
it('returns "downloading" when handler reports downloading', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloading');
|
||||
});
|
||||
|
||||
it('returns "unavailable" when handler reports unavailable', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'unavailable' });
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('returns "unavailable" when the availability() promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'reject' });
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('returns "unavailable" if the utility process dies', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('returns "unavailable" if not registered', async () => {
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('returns "unavailable" if registered but utility process has not set handler', async () => {
|
||||
await forkAndRegisterHandler('no-language-model.js');
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('passes options to the availability() call', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
|
||||
|
||||
const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }] };
|
||||
|
||||
const message = once(aiHandler, 'message');
|
||||
await w.webContents.executeJavaScript(`LanguageModel.availability(${JSON.stringify(options)})`);
|
||||
const [receivedMessage] = await message;
|
||||
|
||||
expect(receivedMessage.options).to.deep.equal(options);
|
||||
expect(receivedMessage.type).to.equal('availability-called');
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.create()', () => {
|
||||
async function expectRejectedWithError (message: string | RegExp, options?: Object) {
|
||||
// Unwrap the error message because NotAllowedError won't serialize
|
||||
if (options) {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)}).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
|
||||
} else {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(message);
|
||||
}
|
||||
}
|
||||
|
||||
it('rejects if invalid value returned', async () => {
|
||||
await forkAndRegisterHandler('buggy-language-model.js');
|
||||
|
||||
await expectRejectedWithError(/unable to create/);
|
||||
});
|
||||
|
||||
it('rejects when no handler is registered', async () => {
|
||||
await expectRejectedWithError(/unable to create/);
|
||||
});
|
||||
|
||||
it('rejects when handler promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'reject' });
|
||||
|
||||
await expectRejectedWithError(/unable to create/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during creation', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during creation', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
|
||||
});
|
||||
|
||||
it('creates a LanguageModel instance from a valid handler', async () => {
|
||||
await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
|
||||
});
|
||||
|
||||
it('passes initialPrompts to create()', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const options = { initialPrompts: [{ role: 'system', content: [{ type: 'text', value: 'You are Electron AI' }] }] };
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'create-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
|
||||
});
|
||||
|
||||
expect(message.options).to.have.property('signal');
|
||||
delete message.options.signal;
|
||||
expect(message.options).to.deep.equal({ initialPrompts: options.initialPrompts.map(prompt => ({ ...prompt, prefix: false })) });
|
||||
});
|
||||
|
||||
it('passes expectedInputs and expectedOutputs options', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'text', languages: ['en', 'fr'] }] };
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'create-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
|
||||
});
|
||||
|
||||
expect(message.options).to.have.property('signal');
|
||||
delete message.options.signal;
|
||||
expect(message.options).to.deep.equal(options);
|
||||
});
|
||||
|
||||
it('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'create-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create({ signal: AbortSignal.timeout(500) }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('exposes contextUsage and contextWindow on the created model', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 12345 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.prompt()', () => {
|
||||
async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
|
||||
// Unwrap the error message because NotAllowedError won't serialize
|
||||
if (options) {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)}, ${JSON.stringify(options)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
|
||||
} else {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
|
||||
}
|
||||
}
|
||||
|
||||
it('rejects when handler returns an invalid value', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
|
||||
|
||||
await expectRejectedWithError(/error occurred/, 'Test prompt');
|
||||
});
|
||||
|
||||
it('rejects when handler promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
|
||||
|
||||
await expectRejectedWithError(/error occurred/, 'Test prompt');
|
||||
});
|
||||
|
||||
it('rejects after the model has been destroyed', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during prompt', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during prompt', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('returns a string response from the handler', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
|
||||
});
|
||||
|
||||
it('returns a ReadableStream response from the handler', async () => {
|
||||
await forkAndRegisterHandler('streaming-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('Hello World');
|
||||
});
|
||||
|
||||
it('passes string input to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt(\'hello world\'))');
|
||||
});
|
||||
|
||||
expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
|
||||
});
|
||||
|
||||
it('passes LanguageModelMessage[] input to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(input)}))`);
|
||||
});
|
||||
|
||||
expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
|
||||
});
|
||||
|
||||
it('passes responseConstraint option to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt('test', { responseConstraint: ${JSON.stringify(responseConstraint)} }))`);
|
||||
});
|
||||
|
||||
expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
|
||||
});
|
||||
|
||||
it('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('updates contextUsage after a prompt', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.prompt("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 10 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.promptStreaming()', () => {
|
||||
const collectStream = 'async (stream) => { const reader = stream.getReader(); let r = ""; while (true) { const { done, value } = await reader.read(); if (done) return r; r += value; } }';
|
||||
|
||||
async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
|
||||
const collectStreamFn = collectStream;
|
||||
// Unwrap the error message because NotAllowedError won't serialize
|
||||
if (options) {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)}, ${JSON.stringify(options)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
|
||||
} else {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
|
||||
}
|
||||
}
|
||||
|
||||
it('rejects when handler returns an invalid value', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
|
||||
|
||||
await expectRejectedWithError(/error occurred/, 'Test prompt');
|
||||
});
|
||||
|
||||
it('rejects when ReadableStream returns an invalid value', async () => {
|
||||
await forkAndRegisterHandler('buggy-streaming-language-model.js');
|
||||
|
||||
await expectRejectedWithError(/has been destroyed/, 'Test prompt');
|
||||
});
|
||||
|
||||
it('rejects when handler promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
|
||||
|
||||
await expectRejectedWithError(/error occurred/, 'Test prompt');
|
||||
});
|
||||
|
||||
it('rejects after the model has been destroyed', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { model.destroy(); const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during prompt', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
|
||||
return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during prompt', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
|
||||
return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('returns a string response from the handler', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('foobar');
|
||||
});
|
||||
|
||||
it('returns a ReadableStream response from the handler', async () => {
|
||||
await forkAndRegisterHandler('streaming-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('Hello World');
|
||||
});
|
||||
|
||||
it('passes string input to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('hello world')); })`);
|
||||
});
|
||||
|
||||
expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
|
||||
});
|
||||
|
||||
it('passes LanguageModelMessage[] input to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming(${JSON.stringify(input)})); })`);
|
||||
});
|
||||
|
||||
expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
|
||||
});
|
||||
|
||||
it('passes responseConstraint option to the handler', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
|
||||
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('test', { responseConstraint: ${JSON.stringify(responseConstraint)} })); })`);
|
||||
});
|
||||
|
||||
expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
|
||||
});
|
||||
|
||||
it('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("test", { signal: AbortSignal.timeout(500) })); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('updates contextUsage after a prompt', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
|
||||
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; await collect(model.promptStreaming("hello world")); return { contextUsage: model.contextUsage }; })`)).to.deep.equal({ contextUsage: 10 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.append()', () => {
|
||||
it('rejects when handler promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'reject' });
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/error occurred/);
|
||||
});
|
||||
|
||||
it('rejects after the model has been destroyed', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.append("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during append', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during append', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('appends a message without producing a response', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.be.undefined();
|
||||
});
|
||||
|
||||
it('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('updates contextUsage after append', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.append("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 5 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.measureContextUsage()', () => {
|
||||
it('rejects if invalid value returned', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'invalid' });
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
|
||||
});
|
||||
|
||||
it('rejects when handler promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'reject' });
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
|
||||
});
|
||||
|
||||
it('rejects after the model has been destroyed', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.measureContextUsage("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during call', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejected();
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during call', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejected();
|
||||
});
|
||||
|
||||
it('returns the token count for the given input', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("hello world"))')).to.equal(42);
|
||||
});
|
||||
|
||||
// TODO(dsanders11): Upstream Chromium issue prevents this test from passing as
|
||||
// there's no Mojo connection to disconnect trip abort signal
|
||||
it.skip('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'measure-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.clone()', () => {
|
||||
it('rejects when clone() returns a non-LanguageModel value', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'invalid' });
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
|
||||
});
|
||||
|
||||
it('rejects when clone() promise rejects', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'reject' });
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
|
||||
});
|
||||
|
||||
it('rejects after the original model has been destroyed', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.clone(); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
it('rejects if the utility process dies during clone', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
aiHandler.kill();
|
||||
await once(aiHandler, 'exit');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
|
||||
});
|
||||
|
||||
it('rejects if the handler gets unregistered during clone', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
|
||||
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
|
||||
});
|
||||
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
|
||||
});
|
||||
|
||||
it('returns a new LanguageModel instance', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
|
||||
});
|
||||
|
||||
it('preserves context from the original model', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript(`
|
||||
LanguageModel.create().then(async (model) => {
|
||||
await model.prompt("hello");
|
||||
const cloned = await model.clone();
|
||||
return { contextUsage: cloned.contextUsage, contextWindow: cloned.contextWindow };
|
||||
})
|
||||
`)).to.deep.equal({ contextUsage: 10, contextWindow: 12345 });
|
||||
});
|
||||
|
||||
it('plumbs the abort signal through', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'clone-aborted', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone({ signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel.destroy()', () => {
|
||||
it('destroys the model', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'destroy-called', async () => {
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test"); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
});
|
||||
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('aborts any in-progress prompt calls', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
|
||||
|
||||
await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
|
||||
return w.webContents.executeJavaScript('window._model.prompt("Test").catch(err => { throw err.message; })');
|
||||
});
|
||||
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
|
||||
await w.webContents.executeJavaScript('window._model.destroy()');
|
||||
});
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('aborts any in-progress append calls', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
|
||||
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
|
||||
|
||||
await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
|
||||
|
||||
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
|
||||
return w.webContents.executeJavaScript('window._model.append("Test").catch(err => { throw err.message; })');
|
||||
});
|
||||
const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
|
||||
await w.webContents.executeJavaScript('window._model.destroy()');
|
||||
});
|
||||
|
||||
await w.webContents.executeJavaScript('window._model.destroy()');
|
||||
|
||||
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
|
||||
expect(message).not.null();
|
||||
});
|
||||
|
||||
it('can be called multiple times without error', async () => {
|
||||
await forkAndRegisterHandler('basic-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); model.destroy(); model.destroy(); return true; })')).to.equal(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('setPromptAPIHandler()', () => {
|
||||
it('rejects if handler returns a promise', async () => {
|
||||
await forkAndRegisterHandler('promise-handler-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('rejects if handler returns a non-class value', async () => {
|
||||
await forkAndRegisterHandler('non-class-handler-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('rejects if handler returns a class not extending LanguageModel', async () => {
|
||||
await forkAndRegisterHandler('non-language-model-handler.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('receives webContentsId in the details object', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
|
||||
await w.webContents.executeJavaScript('LanguageModel.availability()');
|
||||
});
|
||||
|
||||
expect(message.details).to.have.property('webContentsId', w.webContents.id);
|
||||
});
|
||||
|
||||
it('receives securityOrigin in the details object', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
|
||||
await w.webContents.executeJavaScript('LanguageModel.availability()');
|
||||
});
|
||||
|
||||
expect(message.details).to.have.property('securityOrigin');
|
||||
expect(message.details.securityOrigin).to.be.a('string').and.not.be.empty();
|
||||
});
|
||||
|
||||
it('is called once per webContentsId and securityOrigin pair', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
|
||||
|
||||
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
|
||||
await w.webContents.executeJavaScript('LanguageModel.availability()');
|
||||
});
|
||||
|
||||
expect(message.callCount).to.equal(1);
|
||||
|
||||
// Calling availability again should not trigger the handler again
|
||||
await w.webContents.executeJavaScript('LanguageModel.availability()');
|
||||
|
||||
// Create a second window with the same session - should trigger handler again (different webContentsId)
|
||||
const w2 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
session: w.webContents.session,
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
|
||||
const message2 = await listenForMessage(aiHandler, 'handler-called', async () => {
|
||||
await w2.webContents.executeJavaScript('LanguageModel.availability()');
|
||||
});
|
||||
|
||||
expect(message2.callCount).to.equal(2);
|
||||
});
|
||||
|
||||
it('can be cleared by calling with null', async () => {
|
||||
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
// Clear the handler inside the utility process
|
||||
await sendControllableMessage(aiHandler, { command: 'clear-handler' });
|
||||
|
||||
// Existing Prompt API bindings should still work until the page is reloaded
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
// Load a new page to get a fresh Prompt API binding
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
|
||||
// Should be unavailable since setPromptAPIHandler(null) was called in the utility process
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
});
|
||||
|
||||
describe('LanguageModel base class', () => {
|
||||
it('provides default no-op implementations for all methods', async () => {
|
||||
await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Hi"))')).to.be.undefined();
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Hi"))')).to.equal(0);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
|
||||
});
|
||||
|
||||
it('can use the base LanguageModel class directly without subclassing', async () => {
|
||||
await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 0 });
|
||||
});
|
||||
});
|
||||
|
||||
describe('session isolation', () => {
|
||||
it('applies to all windows using the same session', async () => {
|
||||
await forkAndRegisterHandler('default-language-model.js');
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
const w2 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
|
||||
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
});
|
||||
|
||||
it('different sessions can use different handler processes', async () => {
|
||||
const ses1 = session.fromPartition('ai-isolation-1');
|
||||
const ses2 = session.fromPartition('ai-isolation-2');
|
||||
|
||||
const w1 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
session: ses1,
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
const w2 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
session: ses2,
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
|
||||
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
|
||||
]);
|
||||
|
||||
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
|
||||
await once(aiHandler1, 'spawn');
|
||||
ses1.registerLocalAIHandler(aiHandler1);
|
||||
|
||||
const aiHandler2 = utilityProcess.fork(getFixturePath('default-language-model.js'));
|
||||
await once(aiHandler2, 'spawn');
|
||||
ses2.registerLocalAIHandler(aiHandler2);
|
||||
|
||||
try {
|
||||
// basic-language-model returns 'foobar'
|
||||
expect(await w1.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
|
||||
// default-language-model returns ''
|
||||
expect(await w2.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
|
||||
} finally {
|
||||
ses1.registerLocalAIHandler(null);
|
||||
ses2.registerLocalAIHandler(null);
|
||||
}
|
||||
});
|
||||
|
||||
it('clearing one session handler does not affect another', async () => {
|
||||
const ses1 = session.fromPartition('ai-isolation-clear-1');
|
||||
const ses2 = session.fromPartition('ai-isolation-clear-2');
|
||||
|
||||
const w1 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
session: ses1,
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
const w2 = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
session: ses2,
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
|
||||
await Promise.all([
|
||||
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
|
||||
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
|
||||
]);
|
||||
|
||||
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
|
||||
await once(aiHandler1, 'spawn');
|
||||
ses1.registerLocalAIHandler(aiHandler1);
|
||||
|
||||
const aiHandler2 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
|
||||
await once(aiHandler2, 'spawn');
|
||||
ses2.registerLocalAIHandler(aiHandler2);
|
||||
|
||||
try {
|
||||
// Both should be available
|
||||
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
// Clear handler for session 1
|
||||
ses1.registerLocalAIHandler(null);
|
||||
|
||||
// Session 1 should be unavailable, session 2 should still be available
|
||||
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
} finally {
|
||||
ses1.registerLocalAIHandler(null);
|
||||
ses2.registerLocalAIHandler(null);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { app, session, BrowserWindow, net, ipcMain, Session, webFrameMain, WebFrameMain } from 'electron/main';
|
||||
import { app, session, BrowserWindow, net, ipcMain, Session, utilityProcess, webFrameMain, WebFrameMain } from 'electron/main';
|
||||
|
||||
import * as auth from 'basic-auth';
|
||||
import { expect } from 'chai';
|
||||
@@ -2132,4 +2132,96 @@ describe('session module', () => {
|
||||
expect((await cookies.get({ url: 'https://example.org/', name: 'testdotorg' })).length).to.equal(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe('ses.registerLocalAIHandler()', () => {
|
||||
let w: Electron.BrowserWindow;
|
||||
|
||||
beforeEach(() => {
|
||||
w = new BrowserWindow({
|
||||
show: false,
|
||||
webPreferences: {
|
||||
enableBlinkFeatures: 'AIPromptAPI'
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
w.webContents.session.registerLocalAIHandler(null);
|
||||
closeAllWindows();
|
||||
});
|
||||
|
||||
it('registers a utility process as the AI handler', async () => {
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
|
||||
const aiHandler = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
w.webContents.session.registerLocalAIHandler(aiHandler);
|
||||
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
});
|
||||
|
||||
it('clears the handler when called with null', async () => {
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
const { session } = w.webContents;
|
||||
|
||||
const aiHandler = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
session.registerLocalAIHandler(aiHandler);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
session.registerLocalAIHandler(null);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
});
|
||||
|
||||
it('prevents new LanguageModel.create() calls after clearing', async () => {
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
const { session } = w.webContents;
|
||||
|
||||
const aiHandler = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
session.registerLocalAIHandler(aiHandler);
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create()')).to.eventually.be.fulfilled();
|
||||
|
||||
session.registerLocalAIHandler(null);
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/unable to create/);
|
||||
});
|
||||
|
||||
it('can re-register a new handler after clearing', async () => {
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
const { session } = w.webContents;
|
||||
|
||||
const aiHandler1 = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
session.registerLocalAIHandler(aiHandler1);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
session.registerLocalAIHandler(null);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
|
||||
const aiHandler2 = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
session.registerLocalAIHandler(aiHandler2);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
await expect(w.webContents.executeJavaScript('LanguageModel.create()')).to.eventually.be.fulfilled();
|
||||
});
|
||||
|
||||
it('throws when called with a non-UtilityProcess argument', () => {
|
||||
const { session } = w.webContents;
|
||||
|
||||
expect(() => session.registerLocalAIHandler('not a process' as any)).to.throw();
|
||||
expect(() => session.registerLocalAIHandler(42 as any)).to.throw();
|
||||
expect(() => session.registerLocalAIHandler({} as any)).to.throw();
|
||||
});
|
||||
|
||||
it('can register an existing handler again', async () => {
|
||||
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
|
||||
const { session } = w.webContents;
|
||||
|
||||
const aiHandler = utilityProcess.fork(path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), 'default-language-model.js'));
|
||||
|
||||
session.registerLocalAIHandler(aiHandler);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
|
||||
session.registerLocalAIHandler(null);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
|
||||
|
||||
session.registerLocalAIHandler(aiHandler);
|
||||
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
28
spec/fixtures/api/local-ai-handler/basic-language-model.js
vendored
Normal file
28
spec/fixtures/api/local-ai-handler/basic-language-model.js
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
const BasicLanguageModel = class extends LanguageModel {
|
||||
static async create () {
|
||||
return new BasicLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 12345
|
||||
});
|
||||
}
|
||||
|
||||
async prompt () {
|
||||
this.contextUsage += 10;
|
||||
|
||||
return 'foobar';
|
||||
}
|
||||
|
||||
async append () {
|
||||
this.contextUsage += 5;
|
||||
}
|
||||
|
||||
async measureContextUsage () {
|
||||
return 42;
|
||||
}
|
||||
};
|
||||
|
||||
return BasicLanguageModel;
|
||||
});
|
||||
15
spec/fixtures/api/local-ai-handler/buggy-language-model.js
vendored
Normal file
15
spec/fixtures/api/local-ai-handler/buggy-language-model.js
vendored
Normal file
@@ -0,0 +1,15 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
const BuggyLanguageModel = class extends LanguageModel {
|
||||
static async create () {
|
||||
return 'foobar';
|
||||
}
|
||||
|
||||
static availability () {
|
||||
return 'foobar';
|
||||
}
|
||||
};
|
||||
|
||||
return BuggyLanguageModel;
|
||||
});
|
||||
28
spec/fixtures/api/local-ai-handler/buggy-streaming-language-model.js
vendored
Normal file
28
spec/fixtures/api/local-ai-handler/buggy-streaming-language-model.js
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
const { ReadableStream } = require('node:stream/web');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
const BuggyStreamingLanguageModel = class extends LanguageModel {
|
||||
static async create () {
|
||||
return new BuggyStreamingLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 0
|
||||
});
|
||||
}
|
||||
|
||||
async prompt () {
|
||||
this.contextUsage += 10;
|
||||
|
||||
return new ReadableStream({
|
||||
async start (controller) {
|
||||
controller.enqueue('Hello ');
|
||||
controller.enqueue(99);
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return BuggyStreamingLanguageModel;
|
||||
});
|
||||
110
spec/fixtures/api/local-ai-handler/controllable-language-model.js
vendored
Normal file
110
spec/fixtures/api/local-ai-handler/controllable-language-model.js
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
const { once } = require('node:events');
|
||||
|
||||
let availabilityState = 'available';
|
||||
let createResponse = null;
|
||||
let promptResponse = 'Hello World';
|
||||
let appendResponse = null;
|
||||
let measureResponse = 100;
|
||||
let cloneResponse = null;
|
||||
|
||||
process.parentPort.on('message', (e) => {
|
||||
const { command, value } = e.data;
|
||||
if (command === 'set-availability') {
|
||||
availabilityState = value;
|
||||
} else if (command === 'set-create') {
|
||||
createResponse = value;
|
||||
} else if (command === 'set-prompt-response') {
|
||||
promptResponse = value;
|
||||
} else if (command === 'set-append-response') {
|
||||
appendResponse = value;
|
||||
} else if (command === 'set-measure-response') {
|
||||
measureResponse = value;
|
||||
} else if (command === 'set-clone-response') {
|
||||
cloneResponse = value;
|
||||
}
|
||||
|
||||
process.parentPort.postMessage('ack');
|
||||
});
|
||||
|
||||
async function waitForAbort (signal, messageType) {
|
||||
await once(signal, 'abort');
|
||||
process.parentPort.postMessage({ type: messageType });
|
||||
throw new Error('Aborted');
|
||||
}
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
const ControllableLanguageModel = class extends LanguageModel {
|
||||
static async create (options) {
|
||||
process.parentPort.postMessage({ type: 'create-called', options });
|
||||
if (createResponse === 'reject') {
|
||||
return Promise.reject(new Error('Model is unavailable'));
|
||||
} else if (createResponse === 'wait-for-abort') {
|
||||
await waitForAbort(options.signal, 'create-aborted');
|
||||
}
|
||||
return new ControllableLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 0
|
||||
});
|
||||
}
|
||||
|
||||
static async availability (options) {
|
||||
process.parentPort.postMessage({ type: 'availability-called', options });
|
||||
if (availabilityState === 'reject') {
|
||||
return Promise.reject(new Error('Model is unavailable'));
|
||||
}
|
||||
return availabilityState;
|
||||
}
|
||||
|
||||
async prompt (input, options) {
|
||||
process.parentPort.postMessage({ type: 'prompt-called', input, options });
|
||||
if (promptResponse === 'reject') {
|
||||
return Promise.reject(new Error('Model is unavailable'));
|
||||
} else if (promptResponse === 'wait-for-abort') {
|
||||
await waitForAbort(options.signal, 'prompt-aborted');
|
||||
}
|
||||
return promptResponse;
|
||||
}
|
||||
|
||||
async append (input, options) {
|
||||
process.parentPort.postMessage({ type: 'append-called', input, options });
|
||||
if (appendResponse === 'reject') {
|
||||
return Promise.reject(new Error('Append failed'));
|
||||
} else if (appendResponse === 'wait-for-abort') {
|
||||
await waitForAbort(options.signal, 'append-aborted');
|
||||
}
|
||||
}
|
||||
|
||||
async measureContextUsage (input, options) {
|
||||
process.parentPort.postMessage({ type: 'measure-called', input, options });
|
||||
if (measureResponse === 'reject') {
|
||||
return Promise.reject(new Error('Measure failed'));
|
||||
} else if (measureResponse === 'wait-for-abort') {
|
||||
await waitForAbort(options.signal, 'measure-aborted');
|
||||
}
|
||||
return measureResponse;
|
||||
}
|
||||
|
||||
async clone (options) {
|
||||
process.parentPort.postMessage({ type: 'clone-called', options });
|
||||
if (cloneResponse === 'reject') {
|
||||
return Promise.reject(new Error('Clone failed'));
|
||||
} else if (cloneResponse === 'wait-for-abort') {
|
||||
await waitForAbort(options.signal, 'clone-aborted');
|
||||
} else if (cloneResponse === 'invalid') {
|
||||
return 'not-a-language-model';
|
||||
}
|
||||
return new ControllableLanguageModel({
|
||||
contextUsage: this.contextUsage,
|
||||
contextWindow: this.contextWindow
|
||||
});
|
||||
}
|
||||
|
||||
destroy () {
|
||||
process.parentPort.postMessage({ type: 'destroy-called' });
|
||||
}
|
||||
};
|
||||
|
||||
return ControllableLanguageModel;
|
||||
});
|
||||
5
spec/fixtures/api/local-ai-handler/default-language-model.js
vendored
Normal file
5
spec/fixtures/api/local-ai-handler/default-language-model.js
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
return LanguageModel;
|
||||
});
|
||||
18
spec/fixtures/api/local-ai-handler/handler-details-language-model.js
vendored
Normal file
18
spec/fixtures/api/local-ai-handler/handler-details-language-model.js
vendored
Normal file
@@ -0,0 +1,18 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
let callCount = 0;
|
||||
|
||||
process.parentPort.on('message', (e) => {
|
||||
const { command } = e.data;
|
||||
if (command === 'clear-handler') {
|
||||
localAIHandler.setPromptAPIHandler(null);
|
||||
}
|
||||
process.parentPort.postMessage('ack');
|
||||
});
|
||||
|
||||
localAIHandler.setPromptAPIHandler((details) => {
|
||||
callCount++;
|
||||
process.parentPort.postMessage({ type: 'handler-called', details, callCount });
|
||||
|
||||
return LanguageModel;
|
||||
});
|
||||
3
spec/fixtures/api/local-ai-handler/no-language-model.js
vendored
Normal file
3
spec/fixtures/api/local-ai-handler/no-language-model.js
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
process.parentPort.on('message', () => {
|
||||
process.parentPort.postMessage('ack');
|
||||
});
|
||||
5
spec/fixtures/api/local-ai-handler/non-class-handler-language-model.js
vendored
Normal file
5
spec/fixtures/api/local-ai-handler/non-class-handler-language-model.js
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
const { localAIHandler } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
return 'not-a-class';
|
||||
});
|
||||
5
spec/fixtures/api/local-ai-handler/non-language-model-handler.js
vendored
Normal file
5
spec/fixtures/api/local-ai-handler/non-language-model-handler.js
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
const { localAIHandler } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
return class NotALanguageModel {};
|
||||
});
|
||||
5
spec/fixtures/api/local-ai-handler/promise-handler-language-model.js
vendored
Normal file
5
spec/fixtures/api/local-ai-handler/promise-handler-language-model.js
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
return Promise.resolve(LanguageModel);
|
||||
});
|
||||
28
spec/fixtures/api/local-ai-handler/streaming-language-model.js
vendored
Normal file
28
spec/fixtures/api/local-ai-handler/streaming-language-model.js
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
const { localAIHandler, LanguageModel } = require('electron/utility');
|
||||
|
||||
const { ReadableStream } = require('node:stream/web');
|
||||
|
||||
localAIHandler.setPromptAPIHandler(() => {
|
||||
const StreamingLanguageModel = class extends LanguageModel {
|
||||
static async create () {
|
||||
return new StreamingLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 0
|
||||
});
|
||||
}
|
||||
|
||||
async prompt () {
|
||||
this.contextUsage += 10;
|
||||
|
||||
return new ReadableStream({
|
||||
async start (controller) {
|
||||
controller.enqueue('Hello ');
|
||||
controller.enqueue('World');
|
||||
controller.close();
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return StreamingLanguageModel;
|
||||
});
|
||||
@@ -20,7 +20,8 @@ import {
|
||||
session,
|
||||
systemPreferences,
|
||||
webContents,
|
||||
TouchBar
|
||||
TouchBar,
|
||||
utilityProcess
|
||||
} from 'electron/main';
|
||||
|
||||
import { clipboard, crashReporter, nativeImage, shell } from 'electron/common';
|
||||
@@ -1266,6 +1267,9 @@ session.defaultSession.webRequest.onBeforeSendHeaders(filter, function (details:
|
||||
callback({ cancel: false, requestHeaders: details.requestHeaders });
|
||||
});
|
||||
|
||||
session.defaultSession.registerLocalAIHandler(utilityProcess.fork(path.join(__dirname, 'ai-handler.js')));
|
||||
session.defaultSession.registerLocalAIHandler(null);
|
||||
|
||||
app.whenReady().then(function () {
|
||||
const protocol = session.defaultSession.protocol;
|
||||
protocol.registerFileProtocol('atom', function (request, callback) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* eslint-disable */
|
||||
|
||||
import { net, systemPreferences } from 'electron/utility';
|
||||
import { localAIHandler, net, systemPreferences, LanguageModel } from 'electron/utility';
|
||||
|
||||
process.parentPort.on('message', (e) => {
|
||||
if (e.data === 'Hello from parent!') {
|
||||
@@ -65,3 +65,25 @@ if (process.platform === 'darwin') {
|
||||
const value2 = systemPreferences.getUserDefault('Foo', 'boolean');
|
||||
console.log(value2);
|
||||
}
|
||||
|
||||
// localAIHandler
|
||||
// https://github.com/electron/electron/blob/main/docs/api/local-ai-handler.md
|
||||
|
||||
localAIHandler.setPromptAPIHandler((details) => {
|
||||
return class MyLanguageModel extends LanguageModel {
|
||||
private details = details;
|
||||
|
||||
static async create() {
|
||||
return new MyLanguageModel({
|
||||
contextUsage: 0,
|
||||
contextWindow: 0,
|
||||
})
|
||||
}
|
||||
|
||||
async prompt () {
|
||||
return "Hello World"
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
localAIHandler.setPromptAPIHandler(null);
|
||||
|
||||
1
typings/internal-ambient.d.ts
vendored
1
typings/internal-ambient.d.ts
vendored
@@ -19,6 +19,7 @@ declare namespace NodeJS {
|
||||
isPDFViewerEnabled(): boolean;
|
||||
isFakeLocationProviderEnabled(): boolean;
|
||||
isPrintingEnabled(): boolean;
|
||||
isPromptAPIEnabled(): boolean;
|
||||
isExtensionsEnabled(): boolean;
|
||||
isComponentBuild(): boolean;
|
||||
}
|
||||
|
||||
1
typings/internal-electron.d.ts
vendored
1
typings/internal-electron.d.ts
vendored
@@ -146,6 +146,7 @@ declare namespace Electron {
|
||||
|
||||
interface Session {
|
||||
_setDisplayMediaRequestHandler: Electron.Session['setDisplayMediaRequestHandler'];
|
||||
_registerLocalAIHandler(handler: ElectronInternal.UtilityProcessWrapper | null): void;
|
||||
}
|
||||
|
||||
type CreateWindowFunction = (options: BrowserWindowConstructorOptions) => WebContents;
|
||||
|
||||
Reference in New Issue
Block a user