Compare commits

..

1 Commits

Author SHA1 Message Date
David Sanders
39aed69a33 feat: implement the Prompt API via localAIHandler
Assisted-by: Claude Opus 4.6
2026-04-03 23:15:18 -07:00
121 changed files with 4642 additions and 1281 deletions

View File

@@ -48,15 +48,19 @@ runs:
shell: bash
run: echo "::add-matcher::src/electron/.github/problem-matchers/clang.json"
- name: Download previous object checksums
shell: bash
uses: dawidd6/action-download-artifact@09b07ec687d10771279a426c79925ee415c12906 # v17
if: ${{ (github.event_name == 'push' || github.event_name == 'pull_request') && inputs.is-asan != 'true' }}
env:
GITHUB_TOKEN: ${{ github.token }}
ARTIFACT_NAME: object-checksums.${{ inputs.artifact-platform }}_${{ inputs.target-arch }}.json
SEARCH_BRANCH: ${{ case(github.event_name == 'push', github.ref_name, github.event.pull_request.base.ref) }}
REPO: ${{ github.repository }}
OUTPUT_PATH: src/previous-object-checksums.json
run: node src/electron/.github/actions/build-electron/download-previous-object-checksums.mjs
with:
name: object_checksums_${{ inputs.artifact-platform }}_${{ inputs.target-arch }}
commit: ${{ case(github.event_name == 'push', github.event.before, github.event.pull_request.base.sha) }}
path: src
if_no_artifact_found: ignore
- name: Move previous object checksums
shell: bash
run: |
if [ -f src/object-checksums_${{ inputs.artifact-platform }}_${{ inputs.target-arch }}.json ]; then
mv src/object-checksums_${{ inputs.artifact-platform }}_${{ inputs.target-arch }}.json src/previous-object-checksums.json
fi
- name: Build Electron ${{ inputs.step-suffix }}
if: ${{ inputs.target-platform != 'win' }}
shell: bash
@@ -77,7 +81,7 @@ runs:
if [ "${{ inputs.is-release }}" = "true" ]; then
NINJA_SUMMARIZE_BUILD=1 e build --target electron:release_build
else
NINJA_SUMMARIZE_BUILD=1 e build --target electron:release_build
NINJA_SUMMARIZE_BUILD=1 e build --target electron:testing_build
fi
cp out/Default/.ninja_log out/electron_ninja_log
node electron/script/check-symlinks.js
@@ -225,19 +229,10 @@ runs:
fi
- name: Generate FFMpeg ${{ inputs.step-suffix }}
shell: bash
if: ${{ inputs.is-release == 'true' }}
run: |
cd src
# Reuse the hermetic mac_sdk_path that `e build` wrote for out/Default so
# out/ffmpeg builds against the same SDK instead of the runner's system Xcode.
# The path has to live under root_build_dir, so copy the symlink tree and
# rewrite Default -> ffmpeg.
MAC_SDK_ARG=""
if [ "$(uname)" = "Darwin" ]; then
mkdir -p out/ffmpeg
cp -a out/Default/xcode_links out/ffmpeg/
MAC_SDK_ARG=$(sed -n 's|^\(mac_sdk_path = "//out/\)Default/|\1ffmpeg/|p' out/Default/args.gn)
fi
gn gen out/ffmpeg --args="import(\"//electron/build/args/ffmpeg.gn\") use_remoteexec=true use_siso=true $MAC_SDK_ARG $GN_EXTRA_ARGS"
gn gen out/ffmpeg --args="import(\"//electron/build/args/ffmpeg.gn\") use_remoteexec=true use_siso=true $GN_EXTRA_ARGS"
e build --target electron:electron_ffmpeg_zip -C ../../out/ffmpeg
- name: Remove Clang problem matcher
shell: bash

View File

@@ -1,82 +0,0 @@
import { Octokit } from '@octokit/rest';
import { writeFileSync } from 'node:fs';
const token = process.env.GITHUB_TOKEN;
const repo = process.env.REPO;
const artifactName = process.env.ARTIFACT_NAME;
const branch = process.env.SEARCH_BRANCH;
const outputPath = process.env.OUTPUT_PATH;
const required = { GITHUB_TOKEN: token, REPO: repo, ARTIFACT_NAME: artifactName, SEARCH_BRANCH: branch, OUTPUT_PATH: outputPath };
const missing = Object.entries(required).filter(([, v]) => !v).map(([k]) => k);
if (missing.length > 0) {
console.error(`Missing required environment variables: ${missing.join(', ')}`);
process.exit(1);
}
const [owner, repoName] = repo.split('/');
const octokit = new Octokit({ auth: token });
async function main () {
console.log(`Searching for artifact '${artifactName}' on branch '${branch}'...`);
// Resolve the "Build" workflow name to an ID, mirroring how `gh run list --workflow` works
// under the hood (it uses /repos/{owner}/{repo}/actions/workflows/{id}/runs).
const { data: workflows } = await octokit.actions.listRepoWorkflows({ owner, repo: repoName });
const buildWorkflow = workflows.workflows.find((w) => w.name === 'Build');
if (!buildWorkflow) {
console.log('Could not find "Build" workflow, continuing without previous checksums');
return;
}
const { data: runs } = await octokit.actions.listWorkflowRuns({
owner,
repo: repoName,
workflow_id: buildWorkflow.id,
branch,
status: 'completed',
event: 'push',
per_page: 20,
exclude_pull_requests: true
});
for (const run of runs.workflow_runs) {
const { data: artifacts } = await octokit.actions.listWorkflowRunArtifacts({
owner,
repo: repoName,
run_id: run.id,
name: artifactName
});
if (artifacts.artifacts.length > 0) {
const artifact = artifacts.artifacts[0];
console.log(`Found artifact in run ${run.id} (artifact ID: ${artifact.id}), downloading...`);
// Non-archived artifacts are still downloaded from the /zip endpoint
const response = await octokit.actions.downloadArtifact({
owner,
repo: repoName,
artifact_id: artifact.id,
archive_format: 'zip'
});
if (response.headers['content-type'] !== 'application/json') {
console.error(`Unexpected content type for artifact download: ${response.headers['content-type']}`);
console.error('Expected application/json, continuing without previous checksums');
return;
}
writeFileSync(outputPath, JSON.stringify(response.data));
console.log('Downloaded previous object checksums successfully');
return;
}
}
console.log(`No previous object checksums found in last ${runs.workflow_runs.length} runs, continuing without them`);
}
main().catch((err) => {
console.error('Failed to download previous object checksums, continuing without them:', err.message);
process.exit(0);
});

View File

@@ -28,7 +28,7 @@ runs:
shell: bash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH="v2-src-cache-$(cat src/electron/.depshash)"
DEPSHASH="v1-src-cache-$(cat src/electron/.depshash)"
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_FILE=$DEPSHASH.tar" >> $GITHUB_ENV
if [ "${{ inputs.target-platform }}" = "win" ]; then
@@ -109,7 +109,7 @@ runs:
echo "target_os=['$TARGET_OS']" >> ./.gclient
fi
ELECTRON_DEPOT_TOOLS_WIN_TOOLCHAIN=0 DEPOT_TOOLS_WIN_TOOLCHAIN=0 ELECTRON_USE_THREE_WAY_MERGE_FOR_PATCHES=1 e d gclient sync --with_branch_heads --with_tags
ELECTRON_USE_THREE_WAY_MERGE_FOR_PATCHES=1 e d gclient sync --with_branch_heads --with_tags -vv
if [[ "${{ inputs.is-release }}" != "true" ]]; then
# Re-export all the patches to check if there were changes.
python3 src/electron/script/export_all_patches.py src/electron/patches/config.json
@@ -187,9 +187,7 @@ runs:
shell: bash
run: |
echo "Uncompressed src size: $(du -sh src | cut -f1 -d' ')"
# Named .tar but zstd-compressed; the sas-sidecar's filename allowlist
# only permits .tar/.tgz so we keep the extension and decode on restore.
tar -cf - src | zstd -T0 --long=30 -f -o $CACHE_FILE
tar -cf $CACHE_FILE src
echo "Compressed src to $(du -sh $CACHE_FILE | cut -f1 -d' ')"
cp ./$CACHE_FILE $CACHE_DRIVE/
- name: Persist Src Cache

View File

@@ -27,7 +27,6 @@ runs:
python3 src/tools/clang/scripts/update.py
# Refs https://chromium-review.googlesource.com/c/chromium/src/+/6667681
python3 src/tools/clang/scripts/update.py --package objdump
python3 src/tools/clang/scripts/update.py --package clang-tidy
- name: Fix esbuild
if: ${{ inputs.target-platform != 'linux' }}
uses: ./src/electron/.github/actions/cipd-install

View File

@@ -31,7 +31,7 @@ runs:
fi
mkdir temp-cache
zstd -d --long=30 -c $cache_path | tar -xf - -C temp-cache
tar -xf $cache_path -C temp-cache
echo "Unzipped cache is $(du -sh temp-cache/src | cut -f1)"
if [ -d "temp-cache/src" ]; then

View File

@@ -61,9 +61,9 @@ runs:
echo "Cache is empty - exiting"
exit 1
fi
mkdir temp-cache
zstd -d --long=30 -c $DEPSHASH.tar | tar -xf - -C temp-cache
tar -xf $DEPSHASH.tar -C temp-cache
echo "Unzipped cache is $(du -sh temp-cache/src | cut -f1)"
if [ -d "temp-cache/src" ]; then
@@ -85,17 +85,19 @@ runs:
- name: Unzip and Ensure Src Cache (Windows)
if: ${{ inputs.target-platform == 'win' }}
shell: bash
shell: powershell
run: |
echo "Downloaded cache is $(du -sh $DEPSHASH.tar | cut -f1)"
if [ `du $DEPSHASH.tar | cut -f1` = "0" ]; then
echo "Cache is empty - exiting"
$src_cache = "$env:DEPSHASH.tar"
$cache_size = $(Get-Item $src_cache).length
Write-Host "Downloaded cache is $cache_size"
if ($cache_size -eq 0) {
Write-Host "Cache is empty - exiting"
exit 1
fi
}
mkdir temp-cache
zstd -d --long=30 -c $DEPSHASH.tar | tar -xf - -C temp-cache
rm -f $DEPSHASH.tar
$TEMP_DIR=New-Item -ItemType Directory -Path temp-cache
$TEMP_DIR_PATH = $TEMP_DIR.FullName
C:\ProgramData\Chocolatey\bin\7z.exe -y -snld20 x $src_cache -o"$TEMP_DIR_PATH"
- name: Move Src Cache (Windows)
if: ${{ inputs.target-platform == 'win' }}
@@ -110,6 +112,9 @@ runs:
Write-Host "Relocating Cache"
Remove-Item -Recurse -Force src
Move-Item temp-cache\src src
Write-Host "Deleting zip file"
Remove-Item -Force $src_cache
}
if (-Not (Test-Path "src\third_party\blink")) {
Write-Host "Cache was not correctly restored - exiting"

View File

@@ -7,7 +7,6 @@ name: Clean Source Cache
on:
schedule:
- cron: "0 0 * * *"
workflow_dispatch:
permissions: {}
@@ -17,8 +16,6 @@ jobs:
runs-on: electron-arc-centralus-linux-amd64-32core
permissions:
contents: read
env:
DD_API_KEY: ${{ secrets.DD_API_KEY }}
container:
image: ghcr.io/electron/build:bc2f48b2415a670de18d13605b1cf0eb5fdbaae1
options: --user root
@@ -26,130 +23,12 @@ jobs:
- /mnt/cross-instance-cache:/mnt/cross-instance-cache
- /mnt/win-cache:/mnt/win-cache
steps:
- name: Get Disk Space Before Cleanup
id: disk-before
shell: bash
run: |
echo "Disk space before cleanup:"
df -h /mnt/cross-instance-cache
df -h /mnt/win-cache
CROSS_FREE_BEFORE=$(df -k /mnt/cross-instance-cache | tail -1 | awk '{print $4}')
CROSS_TOTAL=$(df -k /mnt/cross-instance-cache | tail -1 | awk '{print $2}')
WIN_FREE_BEFORE=$(df -k /mnt/win-cache | tail -1 | awk '{print $4}')
WIN_TOTAL=$(df -k /mnt/win-cache | tail -1 | awk '{print $2}')
echo "cross_free_kb=$CROSS_FREE_BEFORE" >> $GITHUB_OUTPUT
echo "cross_total_kb=$CROSS_TOTAL" >> $GITHUB_OUTPUT
echo "win_free_kb=$WIN_FREE_BEFORE" >> $GITHUB_OUTPUT
echo "win_total_kb=$WIN_TOTAL" >> $GITHUB_OUTPUT
- name: Cleanup Source Cache
shell: bash
run: |
df -h /mnt/cross-instance-cache
find /mnt/cross-instance-cache -type f -mtime +15 -delete
find /mnt/win-cache -type f -mtime +15 -delete
- name: Get Disk Space After Cleanup
id: disk-after
shell: bash
run: |
echo "Disk space after cleanup:"
df -h /mnt/cross-instance-cache
df -h /mnt/win-cache
CROSS_FREE_AFTER=$(df -k /mnt/cross-instance-cache | tail -1 | awk '{print $4}')
WIN_FREE_AFTER=$(df -k /mnt/win-cache | tail -1 | awk '{print $4}')
echo "cross_free_kb=$CROSS_FREE_AFTER" >> $GITHUB_OUTPUT
echo "win_free_kb=$WIN_FREE_AFTER" >> $GITHUB_OUTPUT
- name: Log Disk Space to Datadog
if: ${{ env.DD_API_KEY != '' }}
shell: bash
env:
CROSS_FREE_BEFORE: ${{ steps.disk-before.outputs.cross_free_kb }}
CROSS_FREE_AFTER: ${{ steps.disk-after.outputs.cross_free_kb }}
CROSS_TOTAL: ${{ steps.disk-before.outputs.cross_total_kb }}
WIN_FREE_BEFORE: ${{ steps.disk-before.outputs.win_free_kb }}
WIN_FREE_AFTER: ${{ steps.disk-after.outputs.win_free_kb }}
WIN_TOTAL: ${{ steps.disk-before.outputs.win_total_kb }}
run: |
TIMESTAMP=$(date +%s)
CROSS_FREE_BEFORE_GB=$(awk "BEGIN {printf \"%.2f\", $CROSS_FREE_BEFORE / 1024 / 1024}")
CROSS_FREE_AFTER_GB=$(awk "BEGIN {printf \"%.2f\", $CROSS_FREE_AFTER / 1024 / 1024}")
CROSS_FREED_GB=$(awk "BEGIN {printf \"%.2f\", ($CROSS_FREE_AFTER - $CROSS_FREE_BEFORE) / 1024 / 1024}")
CROSS_TOTAL_GB=$(awk "BEGIN {printf \"%.2f\", $CROSS_TOTAL / 1024 / 1024}")
WIN_FREE_BEFORE_GB=$(awk "BEGIN {printf \"%.2f\", $WIN_FREE_BEFORE / 1024 / 1024}")
WIN_FREE_AFTER_GB=$(awk "BEGIN {printf \"%.2f\", $WIN_FREE_AFTER / 1024 / 1024}")
WIN_FREED_GB=$(awk "BEGIN {printf \"%.2f\", ($WIN_FREE_AFTER - $WIN_FREE_BEFORE) / 1024 / 1024}")
WIN_TOTAL_GB=$(awk "BEGIN {printf \"%.2f\", $WIN_TOTAL / 1024 / 1024}")
echo "cross-instance-cache: free before=${CROSS_FREE_BEFORE_GB}GB, after=${CROSS_FREE_AFTER_GB}GB, freed=${CROSS_FREED_GB}GB, total=${CROSS_TOTAL_GB}GB"
echo "win-cache: free before=${WIN_FREE_BEFORE_GB}GB, after=${WIN_FREE_AFTER_GB}GB, freed=${WIN_FREED_GB}GB, total=${WIN_TOTAL_GB}GB"
curl -s -X POST "https://api.datadoghq.com/api/v2/series" \
-H "Content-Type: application/json" \
-H "DD-API-KEY: ${DD_API_KEY}" \
-d @- << EOF
{
"series": [
{
"metric": "electron.src_cache.disk.free_space_before_cleanup_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${CROSS_FREE_BEFORE_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:cross-instance-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.free_space_after_cleanup_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${CROSS_FREE_AFTER_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:cross-instance-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.space_freed_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${CROSS_FREED_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:cross-instance-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.total_space_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${CROSS_TOTAL_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:cross-instance-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.free_space_before_cleanup_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${WIN_FREE_BEFORE_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:win-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.free_space_after_cleanup_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${WIN_FREE_AFTER_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:win-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.space_freed_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${WIN_FREED_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:win-cache", "platform:linux"]
},
{
"metric": "electron.src_cache.disk.total_space_gb",
"points": [{"timestamp": ${TIMESTAMP}, "value": ${WIN_TOTAL_GB}}],
"type": 3,
"unit": "gigabyte",
"tags": ["volume:win-cache", "platform:linux"]
}
]
}
EOF
echo "Disk space metrics logged to Datadog"
find /mnt/win-cache -type f -mtime +15 -delete
df -h /mnt/win-cache

View File

@@ -35,7 +35,7 @@ jobs:
- name: Generate DEPS Hash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH=v2-src-cache-$(cat src/electron/.depshash)
DEPSHASH=v1-src-cache-$(cat src/electron/.depshash)
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_PATH=$DEPSHASH.tar" >> $GITHUB_ENV
- name: Restore src cache via AKS

View File

@@ -156,7 +156,7 @@ jobs:
- name: Generate DEPS Hash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH=v2-src-cache-$(cat src/electron/.depshash)
DEPSHASH=v1-src-cache-$(cat src/electron/.depshash)
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_PATH=$DEPSHASH.tar" >> $GITHUB_ENV
- name: Restore src cache via AZCopy

View File

@@ -80,7 +80,7 @@ jobs:
- name: Generate DEPS Hash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH=v2-src-cache-$(cat src/electron/.depshash)
DEPSHASH=v1-src-cache-$(cat src/electron/.depshash)
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_PATH=$DEPSHASH.tar" >> $GITHUB_ENV
- name: Restore src cache via AZCopy

View File

@@ -81,7 +81,7 @@ jobs:
- name: Generate DEPS Hash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH=v2-src-cache-$(cat src/electron/.depshash)
DEPSHASH=v1-src-cache-$(cat src/electron/.depshash)
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_PATH=$DEPSHASH.tar" >> $GITHUB_ENV
- name: Restore src cache via AZCopy

View File

@@ -165,7 +165,7 @@ jobs:
- name: Generate DEPS Hash
run: |
node src/electron/script/generate-deps-hash.js
DEPSHASH=v2-src-cache-$(cat src/electron/.depshash)
DEPSHASH=v1-src-cache-$(cat src/electron/.depshash)
echo "DEPSHASH=$DEPSHASH" >> $GITHUB_ENV
echo "CACHE_PATH=$DEPSHASH.tar" >> $GITHUB_ENV
- name: Restore src cache via AZCopy

View File

@@ -72,7 +72,7 @@ jobs:
Hello @${{ github.event.pull_request.user.login }}. Due to the high amount of AI spam PRs we receive, if a PR is detected to be majority AI-generated without disclosure and untested, we will automatically close the PR.
We welcome the use of AI tools, as long as the PR meets our quality standards and has clearly been built and tested. If you believe your PR was closed in error, we welcome you to resubmit. However, please read our [CONTRIBUTING.md](https://github.com/electron/electron/blob/main/CONTRIBUTING.md) and [AI Tool Policy](https://github.com/electron/governance/blob/main/policy/ai.md) carefully before reopening. Thanks for your contribution.
We welcome the use of AI tools, as long as the PR meets our quality standards and has clearly been built and tested. If you believe your PR was closed in error, we welcome you to resubmit. However, please read our [CONTRIBUTING.md](http://contributing.md/) carefully before reopening. Thanks for your contribution.
- name: Close the pull request
env:
GITHUB_TOKEN: ${{ steps.generate-token.outputs.token }}

View File

@@ -611,6 +611,17 @@ source_set("electron_lib") {
]
}
if (enable_prompt_api) {
sources += [
"shell/browser/ai/proxying_ai_manager.cc",
"shell/browser/ai/proxying_ai_manager.h",
"shell/utility/ai/utility_ai_language_model.cc",
"shell/utility/ai/utility_ai_language_model.h",
"shell/utility/ai/utility_ai_manager.cc",
"shell/utility/ai/utility_ai_manager.h",
]
}
if (is_mac) {
# Disable C++ modules to resolve linking error when including MacOS SDK
# headers from third_party/electron_node/deps/uv/include/uv/darwin.h

View File

@@ -12,6 +12,7 @@ buildflag_header("buildflags") {
"ENABLE_PDF_VIEWER=$enable_pdf_viewer",
"ENABLE_ELECTRON_EXTENSIONS=$enable_electron_extensions",
"ENABLE_BUILTIN_SPELLCHECKER=$enable_builtin_spellchecker",
"ENABLE_PROMPT_API=$enable_prompt_api",
"OVERRIDE_LOCATION_PROVIDER=$enable_fake_location_provider",
]

View File

@@ -17,6 +17,9 @@ declare_args() {
# Enable Spellchecker support
enable_builtin_spellchecker = true
# Enable Prompt API support.
enable_prompt_api = true
# The version of Electron.
# Packagers and vendor builders should set this in gn args to avoid running
# the script that reads git tag.

View File

@@ -0,0 +1,85 @@
## Class: LanguageModel
> Implement local AI language models
Process: [Utility](../glossary.md#utility-process)
### `new LanguageModel(initialState)`
* `initialState` Object
* `contextUsage` number
* `contextWindow` number
> [!NOTE]
> Do not use this constructor directly outside of the class itself, as it will not be properly connected to the `localAIHandler`
### Static Methods
The `LanguageModel` class has the following static methods:
#### `LanguageModel.create(options)` _Experimental_
* `options` [LanguageModelCreateOptions](structures/language-model-create-options.md)
Returns `Promise<LanguageModel>`. Creates a new `LanguageModel` with the provided `options`.
#### `LanguageModel.availability([options])` _Experimental_
* `options` [LanguageModelCreateCoreOptions](structures/language-model-create-core-options.md) (optional)
Returns `Promise<string>`
Determines the availability of the language model and returns one of the following strings:
* `available`
* `downloadable`
* `downloading`
* `unavailable`
### Instance Properties
The following properties are available on instances of `LanguageModel`:
#### `languageModel.contextUsage` _Experimental_
A `number` representing how many tokens are currently in the context window.
#### `languageModel.contextWindow` _Experimental_
A `number` representing the size of the context window, in tokens.
### Instance Methods
The following methods are available on instances of `LanguageModel`:
#### `languageModel.prompt(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
Returns `Promise<string> | Promise<import('stream/web').ReadableStream<string>>`. Prompt the model for a response.
#### `languageModel.append(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelAppendOptions](structures/language-model-append-options.md)
Returns `Promise<undefined>`. Append a message without prompting for a response.
#### `languageModel.measureContextUsage(input, options)` _Experimental_
* `input` [LanguageModelMessage[]](structures/language-model-message.md)
* `options` [LanguageModelPromptOptions](structures/language-model-prompt-options.md)
Returns `Promise<number>`. Measure how many tokens the input would use.
#### `languageModel.clone(options)` _Experimental_
* `options` [LanguageModelCloneOptions](structures/language-model-clone-options.md)
Returns `Promise<LanguageModel>`. Clones the `LanguageModel` such that the
context and initial prompt should be preserved.
#### `languageModel.destroy()` _Experimental_
Destroys the model, and any ongoing executions are aborted.

View File

@@ -0,0 +1,26 @@
# localAIHandler
> Proxy built-in AI APIs to a local LLM implementation
Process: [Utility](../glossary.md#utility-process)
This module is intended to be used by a script registered to a session via
[`ses.registerLocalAIHandler(handler)`](./session.md#sesregisterlocalaihandlerhandler-experimental)
## Methods
The `localAIHandler` module has the following methods:
#### `localAIHandler.setPromptAPIHandler(handler)` _Experimental_
* `handler` Function\<typeof [LanguageModel](language-model.md)\> | null
* `details` Object
* `webContentsId` Integer - The [unique id](web-contents.md#contentsid-readonly) of
the [WebContents](web-contents.md) calling the Prompt API.
* `securityOrigin` string - Origin of the page calling the Prompt API.
Sets the handler for new Prompt API binding requests from the renderer process. This happens
once per pair of `webContentsId` and `securityOrigin`. Clearing the handler by calling
`setPromptAPIHandler(null)` will prevent new Prompt API sessions from being started,
but will not invalidate existing ones. If you want to invalidate existing Prompt API sessions,
clear the local AI handler for the session using `ses.registerLocalAIHandler(null)`.

View File

@@ -1632,6 +1632,12 @@ This method clears more types of data and is more thorough than the
For more information, refer to Chromium's [`BrowsingDataRemover` interface][browsing-data-remover].
#### `ses.registerLocalAIHandler(handler)` _Experimental_
* `handler` [UtilityProcess](utility-process.md#class-utilityprocess) | null
Registers a local AI handler `UtilityProcess`. To clear the handler, call `registerLocalAIHandler(null)`, which will disconnect any existing Prompt API sessions and destroy any `LanguageModel` instances.
### Instance Properties
The following properties are available on instances of `Session`:

View File

@@ -0,0 +1,3 @@
# LanguageModelAppendOptions Object
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,3 @@
# LanguageModelCloneOptions Object
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,4 @@
# LanguageModelCreateCoreOptions Object
* `expectedInputs` [LanguageModelExpected[]](language-model-expected.md) (optional)
* `expectedOutputs` [LanguageModelExpected[]](language-model-expected.md) (optional)

View File

@@ -0,0 +1,4 @@
# LanguageModelCreateOptions Object extends `LanguageModelCreateCoreOptions`
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)
* `initialPrompts` [LanguageModelMessage[]](language-model-message.md) (optional)

View File

@@ -0,0 +1,7 @@
# LanguageModelExpected Object
* `type` string - Can be one of the following values:
* `text`
* `image`
* `audio`
* `languages` string[] (optional)

View File

@@ -0,0 +1,7 @@
# LanguageModelMessageContent Object
* `type` string - Can be one of the following values:
* `text`
* `image`
* `audio`
* `value` ArrayBuffer | string

View File

@@ -0,0 +1,8 @@
# LanguageModelMessage Object
* `role` string - Can be one of the following values:
* `system`
* `user`
* `assistant`
* `content` [LanguageModelMessageContent[]](language-model-message-content.md)
* `prefix` boolean (optional)

View File

@@ -0,0 +1,4 @@
# LanguageModelPromptOptions Object
* `responseConstraint` Object (optional)
* `signal` [AbortSignal](https://nodejs.org/api/globals.html#globals_class_abortsignal)

View File

@@ -0,0 +1,175 @@
---
title: Local AI Handler
description: Handle built-in AI APIs with a local LLM implementation
slug: local-ai-handler
hide_title: true
---
# Local AI Handler
> **This API is experimental.** It may change or be removed in future Electron releases.
Electron supports [Prompt API](https://github.com/webmachinelearning/prompt-api)
(`LanguageModel`) web API by letting you route calls to a local LLM running in a
[utility process](../api/utility-process.md). Web content calls
`LanguageModel.create()` and `LanguageModel.prompt()` like it would in any
browser, while your Electron app decides which model handles the request.
## How it works
The local AI handler architecture involves three processes:
1. **Main process** — creates `UtilityProcess`, and then registers it to handle
Prompt API calls for a given session via [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental).
2. **Utility process** — runs a script that calls
[`localAIHandler.setPromptAPIHandler()`](../api/local-ai-handler.md#localaihandlersetpromptapihandlerhandler-experimental)
to supply a `LanguageModel` subclass.
3. **Renderer process** — web content uses the standard `LanguageModel` API
(e.g. `LanguageModel.create()`, `model.prompt()`).
When a renderer calls the Prompt API, Electron proxies the request through the
main process to the registered utility process, which invokes your
`LanguageModel` implementation and sends the result back directly to the renderer.
## Prerequisites
The Prompt API Blink feature must be enabled on any `BrowserWindow` that will
use it with the `AIPromptAPI` feature. To enable multi-modal inputs, add the
`AIPromptAPIMultimodalInput` as well.
```js
const win = new BrowserWindow({
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
})
```
## Quick start
### 1. Create the utility process script
The utility process script registers your `LanguageModel` subclass. The
handler function receives a `details` object with information about the
caller, and must return a class that extends `LanguageModel`.
```js title='ai-handler.js (Utility Process)'
const { localAIHandler, LanguageModel } = require('electron/utility')
localAIHandler.setPromptAPIHandler((details) => {
// details.webContentsId — ID of the calling WebContents
// details.securityOrigin — origin of the calling page
return class MyLanguageModel extends LanguageModel {
static async create (options) {
// options.signal - AbortSignal to cancel the creation of the model
// options.initialPrompts - initial prompts to pass to the language model
return new MyLanguageModel({
contextUsage: 0,
contextWindow: 4096
})
}
static async availability () {
// Return 'available', 'downloadable', 'downloading', or 'unavailable'
return 'available'
}
async prompt (input) {
// input is a string or LanguageModelMessage[]
// Return a string response from your model, or a ReadableStream
// to return a streaming response.
return 'This is a response from your local LLM!'
}
async clone () {
return new MyLanguageModel({
contextUsage: this.contextUsage,
contextWindow: this.contextWindow
})
}
destroy () {
// Clean up model resources
}
}
})
```
### 2. Register the handler in the main process
Fork the utility process and register it as the AI handler for a session:
```js title='main.js (Main Process)'
const { app, BrowserWindow, utilityProcess } = require('electron')
const path = require('node:path')
app.whenReady().then(() => {
// Fork the utility process running your AI handler script
const aiHandler = utilityProcess.fork(path.join(__dirname, 'ai-handler.js'))
// Create a window with the Prompt API enabled
const win = new BrowserWindow({
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
})
// Connect the AI handler to this session
win.webContents.session.registerLocalAIHandler(aiHandler)
win.loadFile('index.html')
})
```
### 3. Use the Prompt API in your renderer
Your web content can now use the standard `LanguageModel` API, which is a
global available in the renderer:
```html title='index.html (Renderer Process)'
<script>
async function askAI () {
const model = await LanguageModel.create()
const response = await model.prompt('What is Electron?')
document.getElementById('response').textContent = response
}
</script>
<button onclick="askAI()">Ask AI</button>
<p id="response"></p>
```
## Implementing a real model
The quick-start example returns a hardcoded string. A real implementation
would integrate with a local model. See [`electron/llm`](https://github.com/electron/llm)
for an example of using `node-llama-cpp` to wire up GGUF (GPT-Generated Unified Format) models.
## Clearing the handler
To disconnect the AI handler from a session, pass `null`:
```js @ts-type={win:Electron.BrowserWindow}
win.webContents.session.registerLocalAIHandler(null)
```
After clearing, any `LanguageModel.create()` calls from renderers using that
session will fail.
## Security considerations
The `details` object passed to your handler includes `webContentsId` and
`securityOrigin`. Use these to decide whether to handle a request, and
when to reuse a model instance versus providing a fresh instance to
provide proper isolation between origins.
## Further reading
- [`localAIHandler` API reference](../api/local-ai-handler.md)
- [`LanguageModel` API reference](../api/language-model.md)
- [`ses.registerLocalAIHandler()`](../api/session.md#sesregisterlocalaihandlerhandler-experimental)
- [`utilityProcess.fork()`](../api/utility-process.md#utilityprocessforkmodulepath-args-options)
- [`electron/llm`](https://github.com/electron/llm)

View File

@@ -30,6 +30,8 @@ auto_filenames = {
"docs/api/ipc-main-service-worker.md",
"docs/api/ipc-main.md",
"docs/api/ipc-renderer.md",
"docs/api/language-model.md",
"docs/api/local-ai-handler.md",
"docs/api/menu-item.md",
"docs/api/menu.md",
"docs/api/message-channel-main.md",
@@ -108,6 +110,14 @@ auto_filenames = {
"docs/api/structures/jump-list-item.md",
"docs/api/structures/keyboard-event.md",
"docs/api/structures/keyboard-input-event.md",
"docs/api/structures/language-model-append-options.md",
"docs/api/structures/language-model-clone-options.md",
"docs/api/structures/language-model-create-core-options.md",
"docs/api/structures/language-model-create-options.md",
"docs/api/structures/language-model-expected.md",
"docs/api/structures/language-model-message-content.md",
"docs/api/structures/language-model-message.md",
"docs/api/structures/language-model-prompt-options.md",
"docs/api/structures/media-access-permission-request.md",
"docs/api/structures/memory-info.md",
"docs/api/structures/memory-usage-details.md",
@@ -400,6 +410,8 @@ auto_filenames = {
"lib/common/init.ts",
"lib/common/webpack-globals-provider.ts",
"lib/utility/api/exports/electron.ts",
"lib/utility/api/language-model.ts",
"lib/utility/api/local-ai-handler.ts",
"lib/utility/api/module-list.ts",
"lib/utility/api/net.ts",
"lib/utility/init.ts",

View File

@@ -748,6 +748,8 @@ filenames = {
"shell/services/node/node_service.h",
"shell/services/node/parent_port.cc",
"shell/services/node/parent_port.h",
"shell/utility/api/electron_api_local_ai_handler.cc",
"shell/utility/api/electron_api_local_ai_handler.h",
"shell/utility/electron_content_utility_client.cc",
"shell/utility/electron_content_utility_client.h",
]

View File

@@ -2,7 +2,7 @@ import { fetchWithSession } from '@electron/internal/browser/api/net-fetch';
import { addIpcDispatchListeners } from '@electron/internal/browser/ipc-dispatch';
import * as deprecate from '@electron/internal/common/deprecate';
import { net } from 'electron/main';
import { net, type UtilityProcess } from 'electron/main';
const { fromPartition, fromPath, Session } = process._linkedBinding('electron_browser_session');
const { isDisplayMediaSystemPickerAvailable } = process._linkedBinding('electron_browser_desktop_capturer');
@@ -111,6 +111,12 @@ Session.prototype.removeExtension = deprecate.moveAPI(
'session.extensions.removeExtension'
);
Session.prototype.registerLocalAIHandler = function (handler: UtilityProcess | null) {
// We need to unwrap the userland `ForkUtilityProcess` object and get the underlying
// `ElectronInternal.UtilityProcessWrapper` before we call the C++ function
return this._registerLocalAIHandler(handler !== null ? (handler as any)._unwrapHandle() : null);
};
export default {
fromPartition,
fromPath,

View File

@@ -131,6 +131,10 @@ class ForkUtilityProcess extends EventEmitter implements Electron.UtilityProcess
return this.#stderr;
}
_unwrapHandle () {
return this.#handle;
}
postMessage (message: any, transfer?: MessagePortMain[]) {
if (Array.isArray(transfer)) {
transfer = transfer.map((o: any) => o instanceof MessagePortMain ? o._internalPort : o);

View File

@@ -0,0 +1,44 @@
interface LanguageModelConstructorValues {
contextUsage: number;
contextWindow: number;
}
export default class LanguageModel implements Electron.LanguageModel {
contextUsage: number;
contextWindow: number;
constructor (values: LanguageModelConstructorValues) {
this.contextUsage = values.contextUsage;
this.contextWindow = values.contextWindow;
}
static async create (): Promise<LanguageModel> {
return new LanguageModel({
contextUsage: 0,
contextWindow: 0
});
}
static async availability () {
return 'available';
}
async prompt () {
return '';
}
async append (): Promise<undefined> {}
async measureContextUsage () {
return 0;
}
async clone () {
return new LanguageModel({
contextUsage: this.contextUsage,
contextWindow: this.contextWindow
});
}
destroy () {}
}

View File

@@ -0,0 +1,3 @@
const binding = process._linkedBinding('electron_utility_local_ai_handler');
export const setPromptAPIHandler = binding.setPromptAPIHandler;

View File

@@ -1,5 +1,7 @@
// Utility side modules, please sort alphabetically.
export const utilityNodeModuleList: ElectronInternal.ModuleEntry[] = [
{ name: 'localAIHandler', loader: () => require('./local-ai-handler') },
{ name: 'LanguageModel', loader: () => require('./language-model') },
{ name: 'net', loader: () => require('./net') },
{ name: 'systemPreferences', loader: () => require('@electron/internal/browser/api/system-preferences') }
];

View File

@@ -1,6 +1,8 @@
import LanguageModel from '@electron/internal/utility/api/language-model';
import { ParentPort } from '@electron/internal/utility/parent-port';
import { EventEmitter } from 'events';
import { ReadableStream } from 'stream/web';
import { pathToFileURL } from 'url';
const v8Util = process._linkedBinding('electron_common_v8_util');
@@ -10,6 +12,11 @@ const entryScript: string = v8Util.getHiddenValue(process, '_serviceStartupScrip
// we need to restore it here.
process.argv.splice(1, 1, entryScript);
// These are used by C++ to more easily identify these objects.
v8Util.setHiddenValue(global, 'isReadableStream', (val: unknown) => val instanceof ReadableStream);
v8Util.setHiddenValue(global, 'isLanguageModel', (val: unknown) => val instanceof LanguageModel);
v8Util.setHiddenValue(global, 'isLanguageModelClass', (val: any) => Object.is(val, LanguageModel) || val?.prototype instanceof LanguageModel || false);
// Import common settings.
require('@electron/internal/common/init');

View File

@@ -150,3 +150,4 @@ fix_use_fresh_lazynow_for_onendworkitemimpl_after_didruntask.patch
fix_pulseaudio_stream_and_icon_names.patch
fix_fire_menu_popup_start_for_dynamically_created_aria_menus.patch
feat_allow_enabling_extensions_on_custom_protocols.patch
reject_prompt_api_promises_on_mojo_connection_disconnect.patch

View File

@@ -68,7 +68,7 @@ index f91857eb0b6ad385721b8224100de26dfdd7dd8d..45e8766fcb8d46d8edc3bf8d21d3f826
: PdfRenderSettings::Mode::POSTSCRIPT_LEVEL3;
}
diff --git a/chrome/browser/printing/print_view_manager_base.cc b/chrome/browser/printing/print_view_manager_base.cc
index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e96990d04 100644
index aa79c324af2cec50019bca3bccff5d420fb30ffd..455095a2cd63eabe4f267747070b443f0c49c1e8 100644
--- a/chrome/browser/printing/print_view_manager_base.cc
+++ b/chrome/browser/printing/print_view_manager_base.cc
@@ -80,6 +80,20 @@ namespace printing {
@@ -260,18 +260,12 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
if (prefs && prefs->HasPrefPath(prefs::kPrintRasterizePdfDpi)) {
int value = prefs->GetInteger(prefs::kPrintRasterizePdfDpi);
if (value > 0)
@@ -740,8 +765,28 @@ void PrintViewManagerBase::UpdatePrintSettings(
@@ -740,8 +765,22 @@ void PrintViewManagerBase::UpdatePrintSettings(
}
}
-#if BUILDFLAG(IS_WIN)
- // TODO(crbug.com/40260379): Remove this if the printable areas can be made
+#if BUILDFLAG(ENABLE_OOP_PRINTING)
+ if (ShouldPrintJobOop() && !query_with_ui_client_id().has_value()) {
+ RegisterSystemPrintClient();
+ }
+#endif
+
+ std::unique_ptr<PrinterQuery> query =
+ queue_->CreatePrinterQuery(GetCurrentTargetFrame()->GetGlobalId());
+ auto* query_ptr = query.get();
@@ -291,7 +285,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
// fully available from `PrintBackend::GetPrinterSemanticCapsAndDefaults()`
// for in-browser queries.
if (printer_type == mojom::PrinterType::kLocal) {
@@ -762,8 +807,6 @@ void PrintViewManagerBase::UpdatePrintSettings(
@@ -762,8 +801,6 @@ void PrintViewManagerBase::UpdatePrintSettings(
}
#endif
@@ -300,7 +294,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
}
void PrintViewManagerBase::SetAccessibilityTree(
@@ -779,7 +822,7 @@ void PrintViewManagerBase::SetAccessibilityTree(
@@ -779,7 +816,7 @@ void PrintViewManagerBase::SetAccessibilityTree(
void PrintViewManagerBase::IsPrintingEnabled(
IsPrintingEnabledCallback callback) {
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
@@ -309,7 +303,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
}
void PrintViewManagerBase::ScriptedPrint(mojom::ScriptedPrintParamsPtr params,
@@ -805,7 +848,7 @@ void PrintViewManagerBase::ScriptedPrint(mojom::ScriptedPrintParamsPtr params,
@@ -805,7 +842,7 @@ void PrintViewManagerBase::ScriptedPrint(mojom::ScriptedPrintParamsPtr params,
return;
}
#endif
@@ -318,7 +312,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
std::optional<enterprise_connectors::ContentAnalysisDelegate::Data>
scanning_data = enterprise_data_protection::GetPrintAnalysisData(
web_contents(), enterprise_data_protection::PrintScanningContext::
@@ -835,11 +878,9 @@ void PrintViewManagerBase::PrintingFailed(int32_t cookie,
@@ -835,11 +872,9 @@ void PrintViewManagerBase::PrintingFailed(int32_t cookie,
// destroyed. In such cases the error notification to the user will
// have already been displayed, and a second message should not be
// shown.
@@ -332,7 +326,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
ReleasePrinterQuery();
}
@@ -851,15 +892,33 @@ void PrintViewManagerBase::RemoveTestObserver(TestObserver& observer) {
@@ -851,15 +886,33 @@ void PrintViewManagerBase::RemoveTestObserver(TestObserver& observer) {
test_observers_.RemoveObserver(&observer);
}
@@ -366,7 +360,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
}
void PrintViewManagerBase::RenderFrameDeleted(
@@ -901,13 +960,14 @@ void PrintViewManagerBase::SystemDialogCancelled() {
@@ -901,13 +954,14 @@ void PrintViewManagerBase::SystemDialogCancelled() {
// System dialog was cancelled. Clean up the print job and notify the
// BackgroundPrintingManager.
DCHECK_CURRENTLY_ON(content::BrowserThread::UI);
@@ -382,7 +376,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
}
void PrintViewManagerBase::OnDocDone(int job_id, PrintedDocument* document) {
@@ -921,18 +981,26 @@ void PrintViewManagerBase::OnJobDone() {
@@ -921,18 +975,26 @@ void PrintViewManagerBase::OnJobDone() {
// Printing is done, we don't need it anymore.
// print_job_->is_job_pending() may still be true, depending on the order
// of object registration.
@@ -411,7 +405,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
TerminatePrintJob(true);
}
@@ -942,7 +1010,7 @@ bool PrintViewManagerBase::RenderAllMissingPagesNow() {
@@ -942,7 +1004,7 @@ bool PrintViewManagerBase::RenderAllMissingPagesNow() {
// Is the document already complete?
if (print_job_->document() && print_job_->document()->IsComplete()) {
@@ -420,7 +414,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
return true;
}
@@ -995,7 +1063,10 @@ bool PrintViewManagerBase::SetupNewPrintJob(
@@ -995,7 +1057,10 @@ bool PrintViewManagerBase::SetupNewPrintJob(
// Disconnect the current `print_job_`.
auto weak_this = weak_ptr_factory_.GetWeakPtr();
@@ -432,7 +426,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
if (!weak_this)
return false;
@@ -1015,7 +1086,7 @@ bool PrintViewManagerBase::SetupNewPrintJob(
@@ -1015,7 +1080,7 @@ bool PrintViewManagerBase::SetupNewPrintJob(
#endif
print_job_->AddObserver(*this);
@@ -441,7 +435,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
return true;
}
@@ -1073,7 +1144,7 @@ void PrintViewManagerBase::ReleasePrintJob() {
@@ -1073,7 +1138,7 @@ void PrintViewManagerBase::ReleasePrintJob() {
// Ensure that any residual registration of printing client is released.
// This might be necessary in some abnormal cases, such as the associated
// render process having terminated.
@@ -450,7 +444,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
if (!analyzing_content_) {
UnregisterSystemPrintClient();
}
@@ -1083,6 +1154,11 @@ void PrintViewManagerBase::ReleasePrintJob() {
@@ -1083,6 +1148,11 @@ void PrintViewManagerBase::ReleasePrintJob() {
}
#endif
@@ -462,7 +456,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
if (!print_job_)
return;
@@ -1090,7 +1166,7 @@ void PrintViewManagerBase::ReleasePrintJob() {
@@ -1090,7 +1160,7 @@ void PrintViewManagerBase::ReleasePrintJob() {
// printing_rfh_ should only ever point to a RenderFrameHost with a live
// RenderFrame.
DCHECK(rfh->IsRenderFrameLive());
@@ -471,7 +465,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
}
print_job_->RemoveObserver(*this);
@@ -1132,7 +1208,7 @@ bool PrintViewManagerBase::RunInnerMessageLoop() {
@@ -1132,7 +1202,7 @@ bool PrintViewManagerBase::RunInnerMessageLoop() {
}
bool PrintViewManagerBase::OpportunisticallyCreatePrintJob(int cookie) {
@@ -480,7 +474,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
return true;
if (!cookie) {
@@ -1155,7 +1231,7 @@ bool PrintViewManagerBase::OpportunisticallyCreatePrintJob(int cookie) {
@@ -1155,7 +1225,7 @@ bool PrintViewManagerBase::OpportunisticallyCreatePrintJob(int cookie) {
return false;
}
@@ -489,7 +483,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
// Don't start printing if enterprise checks are being performed to check if
// printing is allowed, or if content analysis is going to take place right
// before starting `print_job_`.
@@ -1286,6 +1362,8 @@ void PrintViewManagerBase::CompleteScriptedPrint(
@@ -1286,6 +1356,8 @@ void PrintViewManagerBase::CompleteScriptedPrint(
auto callback_wrapper = base::BindOnce(
&PrintViewManagerBase::ScriptedPrintReply, weak_ptr_factory_.GetWeakPtr(),
std::move(callback), render_process_host->GetDeprecatedID());
@@ -498,7 +492,7 @@ index aa79c324af2cec50019bca3bccff5d420fb30ffd..0b85598f87673537eccdd0b310e8462e
std::unique_ptr<PrinterQuery> printer_query =
queue()->PopPrinterQuery(params->cookie);
if (!printer_query)
@@ -1296,10 +1374,10 @@ void PrintViewManagerBase::CompleteScriptedPrint(
@@ -1296,10 +1368,10 @@ void PrintViewManagerBase::CompleteScriptedPrint(
params->expected_pages_count, params->has_selection, params->margin_type,
params->is_scripted, !render_process_host->IsPdf(),
base::BindOnce(&OnDidScriptedPrint, queue_, std::move(printer_query),
@@ -620,7 +614,7 @@ index 2a477e820d9f0126a05f86cd44f02c2189275bad..a2e9442ff9f5acf8e301f457b1806251
#if BUILDFLAG(IS_CHROMEOS)
diff --git a/chrome/browser/printing/printer_query_oop.cc b/chrome/browser/printing/printer_query_oop.cc
index dc2a15ab4d784b0b6c85b84a30c3c08a17ed8e3d..e197026e8a7f132c1bf90a0f5f1eabb4f5f064ee 100644
index dc2a15ab4d784b0b6c85b84a30c3c08a17ed8e3d..8facb5981cc421cad6bce71dfa8985b0a3270405 100644
--- a/chrome/browser/printing/printer_query_oop.cc
+++ b/chrome/browser/printing/printer_query_oop.cc
@@ -126,7 +126,7 @@ void PrinterQueryOop::OnDidAskUserForSettings(
@@ -632,28 +626,6 @@ index dc2a15ab4d784b0b6c85b84a30c3c08a17ed8e3d..e197026e8a7f132c1bf90a0f5f1eabb4
// Want the same PrintBackend service as the query so that we use the same
// device context.
print_document_client_id_ =
@@ -189,6 +189,21 @@ void PrinterQueryOop::GetSettingsWithUI(uint32_t document_page_count,
// browser process.
// - Other platforms don't have a system print UI or do not use OOP
// printing, so this does not matter.
+
+ // Apply cached settings to the local printing context so that the in-browser
+ // system print dialog is prefilled with user-specified options (e.g. copies,
+ // collate, duplex). OOP UpdatePrintSettings only applies settings to the
+ // remote service context, not the local one used by the native dialog.
+ if (settings().dpi()) {
+ printing_context()->SetPrintSettings(settings());
+ printing_context()->UpdatePrinterSettings(PrintingContext::PrinterSettings{
+#if BUILDFLAG(IS_MAC)
+ .external_preview = false,
+#endif
+ .show_system_dialog = false,
+ });
+ }
+
PrinterQuery::GetSettingsWithUI(
document_page_count, has_selection, is_scripted,
base::BindOnce(&PrinterQueryOop::OnDidAskUserForSettings,
diff --git a/components/printing/browser/print_manager.cc b/components/printing/browser/print_manager.cc
index 21c81377d32ae8d4185598a7eba88ed1d2063ef0..0767f4e9369e926b1cea99178c1a1975941f1765 100644
--- a/components/printing/browser/print_manager.cc
@@ -703,7 +675,7 @@ index ac2f719be566020d9f41364560c12e6d6d0fe3d8..16d758a6936f66148a196761cfb875f6
PrintingFailed(int32 cookie, PrintFailureReason reason);
diff --git a/components/printing/renderer/print_render_frame_helper.cc b/components/printing/renderer/print_render_frame_helper.cc
index 60b5e83a8bc1ed07970be4cdfdc19962698bd754..dd83b6cfb6e3f916e60f50402014cd931a4d8850 100644
index 60b5e83a8bc1ed07970be4cdfdc19962698bd754..1320f3b10b07b2cee90f39f406604176c7575796 100644
--- a/components/printing/renderer/print_render_frame_helper.cc
+++ b/components/printing/renderer/print_render_frame_helper.cc
@@ -54,6 +54,7 @@
@@ -827,7 +799,7 @@ index 60b5e83a8bc1ed07970be4cdfdc19962698bd754..dd83b6cfb6e3f916e60f50402014cd93
// Check if `this` is still valid.
if (!self)
return;
@@ -2394,29 +2415,47 @@ void PrintRenderFrameHelper::IPCProcessed() {
@@ -2394,29 +2415,43 @@ void PrintRenderFrameHelper::IPCProcessed() {
}
bool PrintRenderFrameHelper::InitPrintSettings(blink::WebLocalFrame* frame,
@@ -863,12 +835,8 @@ index 60b5e83a8bc1ed07970be4cdfdc19962698bd754..dd83b6cfb6e3f916e60f50402014cd93
- : mojom::PrintScalingOption::kSourceSize;
- RecordDebugEvent(settings.params->printed_doc_type ==
+ bool silent = new_settings.FindBool("silent").value_or(false);
+ int margins_type = new_settings.FindInt(kSettingMarginsType)
+ .value_or(static_cast<int>(mojom::MarginType::kDefaultMargins));
+ if (silent &&
+ margins_type == static_cast<int>(mojom::MarginType::kDefaultMargins)) {
+ settings->params->print_scaling_option =
+ mojom::PrintScalingOption::kFitToPrintableArea;
+ if (silent) {
+ settings->params->print_scaling_option = mojom::PrintScalingOption::kFitToPrintableArea;
+ } else {
+ settings->params->print_scaling_option =
+ center_on_paper ? mojom::PrintScalingOption::kCenterShrinkToFitPaper

View File

@@ -0,0 +1,168 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: David Sanders <dsanders11@ucsbalum.com>
Date: Wed, 1 Apr 2026 21:14:38 -0700
Subject: Reject Prompt API promises on Mojo connection disconnect
Without these changes to reject promises when the Mojo connection
disconnects, these promises will hang indefinitely if the Prompt
API handler is killed or unregistered.
This will be upstreamed to Chromium.
Change-Id: I89a6a076ae35cbaf12a93c517223a524bab3dff0
diff --git a/third_party/blink/renderer/modules/ai/language_model.cc b/third_party/blink/renderer/modules/ai/language_model.cc
index c176575f2dcc049e478d5388ae1934aa3bc59786..b8eda3dd8733b1e7e92f70a8c75678df4a0314c8 100644
--- a/third_party/blink/renderer/modules/ai/language_model.cc
+++ b/third_party/blink/renderer/modules/ai/language_model.cc
@@ -190,6 +190,9 @@ class CloneLanguageModelClient
client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
language_model->GetTaskRunner());
+ receiver_.set_disconnect_handler(
+ BindOnce(&CloneLanguageModelClient::OnConnectionError,
+ WrapWeakPersistent(this)));
language_model_->GetAILanguageModelRemote()->Fork(std::move(client_remote));
}
~CloneLanguageModelClient() override = default;
@@ -232,6 +235,11 @@ class CloneLanguageModelClient
Cleanup();
}
+ void OnConnectionError() {
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
+ /*quota_error_info=*/nullptr);
+ }
+
void ResetReceiver() override { receiver_.reset(); }
private:
@@ -262,6 +270,8 @@ class AppendClient : public GarbageCollected<AppendClient>,
mojo::PendingRemote<mojom::blink::ModelStreamingResponder> client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(),
language_model->GetTaskRunner());
+ receiver_.set_disconnect_handler(
+ BindOnce(&AppendClient::OnConnectionError, WrapWeakPersistent(this)));
language_model_->GetAILanguageModelRemote()->Append(
std::move(prompts), std::move(client_remote));
}
@@ -317,6 +327,11 @@ class AppendClient : public GarbageCollected<AppendClient>,
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void OnStreaming(const String& text) override {
NOTREACHED() << "Append() should not invoke `OnStreaming()`";
}
@@ -761,6 +776,7 @@ void LanguageModel::ExecuteMeasureInputUsage(
ScriptPromiseResolver<IDLDouble>* resolver,
AbortSignal* signal,
Vector<mojom::blink::AILanguageModelPromptPtr> prompts) {
+ auto reject_fn = RejectOnDestruction(resolver, signal);
language_model_remote_->MeasureInputUsage(
std::move(prompts),
BindOnce(
@@ -783,7 +799,8 @@ void LanguageModel::ExecuteMeasureInputUsage(
}
resolver->Resolve(static_cast<double>(usage.value()));
},
- WrapPersistent(resolver), WrapPersistent(signal)));
+ WrapPersistent(resolver), WrapPersistent(signal))
+ .Then(std::move(reject_fn)));
}
bool LanguageModel::ValidateInput(ScriptState* script_state,
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.cc b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
index ddc6fcda3ffbdc271bcdebfbd85aa711c063fee2..2b63e9e77dc1ed4a0ed9a687527efdc104974e7a 100644
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.cc
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.cc
@@ -509,6 +509,11 @@ void LanguageModelCreateClient::OnError(
Cleanup();
}
+void LanguageModelCreateClient::OnConnectionError() {
+ OnError(mojom::blink::AIManagerCreateClientError::kUnableToCreateSession,
+ /*quota_error_info=*/nullptr);
+}
+
void LanguageModelCreateClient::ResetReceiver() {
receiver_.reset();
}
@@ -524,6 +529,8 @@ void LanguageModelCreateClient::OnInitialPromptsResolved(
mojo::PendingRemote<mojom::blink::AIManagerCreateLanguageModelClient>
client_remote;
receiver_.Bind(client_remote.InitWithNewPipeAndPassReceiver(), task_runner_);
+ receiver_.set_disconnect_handler(
+ BindOnce(&LanguageModelCreateClient::OnConnectionError, WrapWeakPersistent(this)));
HeapMojoRemote<mojom::blink::AIManager>& ai_manager_remote =
AIInterfaceProxy::GetAIManagerRemote(GetExecutionContext());
diff --git a/third_party/blink/renderer/modules/ai/language_model_create_client.h b/third_party/blink/renderer/modules/ai/language_model_create_client.h
index 9ed8dfbefeccf1627d56f5ccc315f06071a63e25..7c8e823608883171c115676db151b70eb2fd055d 100644
--- a/third_party/blink/renderer/modules/ai/language_model_create_client.h
+++ b/third_party/blink/renderer/modules/ai/language_model_create_client.h
@@ -49,6 +49,8 @@ class LanguageModelCreateClient
// Process options and create, if the availability result is valid.
void Create(mojom::blink::ModelAvailabilityCheckResult result);
+ void OnConnectionError();
+
// Continue creation after any initial prompts were processed or rejected.
void OnInitialPromptsResolved(
Vector<mojom::blink::AILanguageModelExpectedPtr> expected_inputs,
diff --git a/third_party/blink/renderer/modules/ai/model_execution_responder.cc b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
index 47b65b13adfab4b8f2597a23d38a386915643d1b..fa9b54e1069019a66b8dab6eb0efe5df8b34c11a 100644
--- a/third_party/blink/renderer/modules/ai/model_execution_responder.cc
+++ b/third_party/blink/renderer/modules/ai/model_execution_responder.cc
@@ -84,7 +84,10 @@ class Responder final : public GarbageCollected<Responder>,
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
- return receiver_.BindNewPipeAndPassRemote(task_runner);
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
+ receiver_.set_disconnect_handler(
+ BindOnce(&Responder::OnConnectionError, WrapWeakPersistent(this)));
+ return pending_remote;
}
// `mojom::blink::ModelStreamingResponder` implementation.
@@ -144,6 +147,11 @@ class Responder final : public GarbageCollected<Responder>,
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(
@@ -235,7 +243,10 @@ class StreamingResponder final
mojo::PendingRemote<blink::mojom::blink::ModelStreamingResponder>
BindNewPipeAndPassRemote(
scoped_refptr<base::SequencedTaskRunner> task_runner) {
- return receiver_.BindNewPipeAndPassRemote(task_runner);
+ auto pending_remote = receiver_.BindNewPipeAndPassRemote(task_runner);
+ receiver_.set_disconnect_handler(BindOnce(
+ &StreamingResponder::OnConnectionError, WrapWeakPersistent(this)));
+ return pending_remote;
}
ReadableStream* CreateReadableStream() {
@@ -337,6 +348,11 @@ class StreamingResponder final
Cleanup();
}
+ void OnConnectionError() {
+ OnError(ModelStreamingResponseStatus::kErrorSessionDestroyed,
+ /*quota_error_info=*/nullptr);
+ }
+
void RecordResponseStatusMetrics(
mojom::blink::ModelStreamingResponseStatus status) {
base::UmaHistogramEnumeration(

View File

@@ -2,8 +2,6 @@
set -euo pipefail
export XDG_SESSION_TYPE=wayland
# On a Wayland desktop, the tests will use your active display and compositor.
# To run headlessly in weston like in CI, set WAYLAND_DISPLAY=wayland-99.
export WAYLAND_DISPLAY="${WAYLAND_DISPLAY:-wayland-99}"
if [[ -z "${XDG_RUNTIME_DIR:-}" ]]; then

View File

@@ -1,5 +1,4 @@
spec/parse-features-string-spec.ts
spec/types-spec.ts
spec/version-bump-spec.ts
spec/api-app-spec.ts
spec/api-browser-window-spec.ts
spec/api-app-spec.ts

View File

@@ -0,0 +1,187 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/browser/ai/proxying_ai_manager.h"
#include <optional>
#include <utility>
#include "base/functional/bind.h"
#include "base/notimplemented.h"
#include "content/public/browser/browser_context.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/weak_document_ptr.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "shell/browser/api/electron_api_session.h"
#include "shell/browser/api/electron_api_web_contents.h"
#include "shell/browser/session_preferences.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
namespace electron {
ProxyingAIManager::ProxyingAIManager(content::BrowserContext* browser_context,
content::RenderFrameHost* rfh)
: browser_context_(browser_context),
rfh_(rfh ? rfh->GetWeakDocumentPtr() : content::WeakDocumentPtr()) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
if (session_prefs) {
ai_handler_changed_subscription_ =
session_prefs->AddAIHandlerChangedCallback(
base::BindRepeating(&ProxyingAIManager::OnAIHandlerChanged,
weak_ptr_factory_.GetWeakPtr()));
}
}
ProxyingAIManager::~ProxyingAIManager() = default;
void ProxyingAIManager::OnAIHandlerChanged() {
ai_manager_remote_.reset();
}
void ProxyingAIManager::AddReceiver(
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
receivers_.Add(this, std::move(receiver));
}
const mojo::Remote<blink::mojom::AIManager>&
ProxyingAIManager::GetAIManagerRemote(const SessionPreferences& session_prefs) {
if (!ai_manager_remote_.is_bound()) {
auto* local_ai_handler = session_prefs.GetLocalAIHandler().get();
if (local_ai_handler) {
auto* rfh = rfh_.AsRenderFrameHostIfValid();
DCHECK(rfh);
auto* web_contents = electron::api::WebContents::From(
content::WebContents::FromRenderFrameHost(rfh));
std::optional<int32_t> web_contents_id;
if (web_contents) {
web_contents_id = web_contents->ID();
}
local_ai_handler->BindAIManager(
web_contents_id, rfh->GetLastCommittedOrigin(),
ai_manager_remote_.BindNewPipeAndPassReceiver());
}
}
return ai_manager_remote_;
}
void ProxyingAIManager::CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
DCHECK(session_prefs);
// Default to unavailable. This ensures the callback is always invoked
// even if there is no registered utility process handler, or the
// process crashes.
auto cb = mojo::WrapCallbackWithDefaultInvokeIfNotRun(
std::move(callback),
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
// Proxy the call through to the utility process
auto& ai_manager = GetAIManagerRemote(*session_prefs);
if (ai_manager.is_bound()) {
ai_manager->CanCreateLanguageModel(std::move(options), std::move(cb));
}
}
void ProxyingAIManager::CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
auto* session_prefs =
SessionPreferences::FromBrowserContext(browser_context_);
DCHECK(session_prefs);
// Proxy the call through to the utility process
auto& ai_manager = GetAIManagerRemote(*session_prefs);
if (!ai_manager.is_bound()) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
client_remote->OnError(
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
/*quota_error_info=*/nullptr);
return;
}
ai_manager->CreateLanguageModel(std::move(client), std::move(options));
}
void ProxyingAIManager::CanCreateSummarizer(
blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::GetLanguageModelParams(
GetLanguageModelParamsCallback callback) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateWriter(
blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateRewriter(
blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::CanCreateProofreader(
blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void ProxyingAIManager::CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
blink::mojom::AIProofreaderCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void ProxyingAIManager::AddModelDownloadProgressObserver(
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
observer_remote) {
NOTIMPLEMENTED();
}
} // namespace electron

View File

@@ -0,0 +1,103 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
#define ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_
#include "base/callback_list.h"
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "content/public/browser/weak_document_ptr.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver_set.h"
#include "shell/browser/session_preferences.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
namespace content {
class BrowserContext;
class RenderFrameHost;
} // namespace content
namespace electron {
// Owned by the host of the document / service worker via `SupportUserData`.
// The browser-side implementation of `blink::mojom::AIManager`, which
// proxies requests to a utility process if the session has a registered
// handler.
class ProxyingAIManager : public base::SupportsUserData::Data,
public blink::mojom::AIManager {
public:
ProxyingAIManager(content::BrowserContext* browser_context,
content::RenderFrameHost* rfh);
ProxyingAIManager(const ProxyingAIManager&) = delete;
ProxyingAIManager& operator=(const ProxyingAIManager&) = delete;
~ProxyingAIManager() override;
void AddReceiver(mojo::PendingReceiver<blink::mojom::AIManager> receiver);
private:
// Lazily bind the AIManager remote so that the developer can
// set the local AI handler after this class is already created
[[nodiscard]] const mojo::Remote<blink::mojom::AIManager>& GetAIManagerRemote(
const SessionPreferences& session_prefs);
void OnAIHandlerChanged();
// `blink::mojom::AIManager` implementation.
void CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) override;
void CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) override;
void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) override;
void CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) override;
void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) override;
void CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) override;
void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) override;
void CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) override;
void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) override;
void CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
client,
blink::mojom::AIProofreaderCreateOptionsPtr options) override;
void AddModelDownloadProgressObserver(
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
observer_remote) override;
mojo::ReceiverSet<blink::mojom::AIManager> receivers_;
raw_ptr<content::BrowserContext> browser_context_;
content::WeakDocumentPtr rfh_;
mojo::Remote<blink::mojom::AIManager> ai_manager_remote_;
base::CallbackListSubscription ai_handler_changed_subscription_;
base::WeakPtrFactory<ProxyingAIManager> weak_ptr_factory_{this};
};
} // namespace electron
#endif // ELECTRON_SHELL_BROWSER_AI_PROXYING_AI_MANAGER_H_

View File

@@ -264,6 +264,11 @@ void BrowserWindow::BlurWebView() {
web_contents()->GetRenderViewHost()->GetWidget()->Blur();
}
bool BrowserWindow::IsWebViewFocused() {
auto* host_view = web_contents()->GetRenderViewHost()->GetWidget()->GetView();
return host_view && host_view->HasFocus();
}
v8::Local<v8::Value> BrowserWindow::GetWebContents(v8::Isolate* isolate) {
if (web_contents_.IsEmpty())
return v8::Null(isolate);
@@ -327,6 +332,7 @@ void BrowserWindow::BuildPrototype(v8::Isolate* isolate,
gin_helper::ObjectTemplateBuilder(isolate, prototype->PrototypeTemplate())
.SetMethod("focusOnWebView", &BrowserWindow::FocusOnWebView)
.SetMethod("blurWebView", &BrowserWindow::BlurWebView)
.SetMethod("isWebViewFocused", &BrowserWindow::IsWebViewFocused)
.SetProperty("webContents", &BrowserWindow::GetWebContents);
}

View File

@@ -73,6 +73,7 @@ class BrowserWindow : public BaseWindow,
// BrowserWindow APIs.
void FocusOnWebView();
void BlurWebView();
bool IsWebViewFocused();
v8::Local<v8::Value> GetWebContents(v8::Isolate* isolate);
private:

View File

@@ -220,6 +220,12 @@ download::DownloadItem::DownloadState DownloadItem::GetState() const {
return download_item_->GetState();
}
bool DownloadItem::IsDone() const {
if (!CheckAlive())
return false;
return download_item_->IsDone();
}
void DownloadItem::SetSavePath(const base::FilePath& path) {
save_path_ = path;
}
@@ -283,6 +289,7 @@ gin::ObjectTemplateBuilder DownloadItem::GetObjectTemplateBuilder(
.SetMethod("getURL", &DownloadItem::GetURL)
.SetMethod("getURLChain", &DownloadItem::GetURLChain)
.SetMethod("getState", &DownloadItem::GetState)
.SetMethod("isDone", &DownloadItem::IsDone)
.SetMethod("setSavePath", &DownloadItem::SetSavePath)
.SetMethod("getSavePath", &DownloadItem::GetSavePath)
.SetProperty("savePath", &DownloadItem::GetSavePath,

View File

@@ -78,6 +78,7 @@ class DownloadItem final : public gin_helper::DeprecatedWrappable<DownloadItem>,
const GURL& GetURL() const;
v8::Local<v8::Value> GetURLChain() const;
download::DownloadItem::DownloadState GetState() const;
bool IsDone() const;
void SetSaveDialogOptions(const file_dialog::DialogSettings& options);
std::string GetLastModifiedTime() const;
std::string GetETag() const;

View File

@@ -264,6 +264,22 @@ int Menu::GetItemCount() const {
return model_->GetItemCount();
}
int Menu::GetCommandIdAt(int index) const {
return model_->GetCommandIdAt(index);
}
std::u16string Menu::GetLabelAt(int index) const {
return model_->GetLabelAt(index);
}
std::u16string Menu::GetSublabelAt(int index) const {
return model_->GetSecondaryLabelAt(index);
}
std::u16string Menu::GetToolTipAt(int index) const {
return model_->GetToolTipAt(index);
}
std::u16string Menu::GetAcceleratorTextAtForTesting(int index) const {
ui::Accelerator accelerator;
model_->GetAcceleratorAtWithParams(index, true, &accelerator);
@@ -282,6 +298,10 @@ bool Menu::IsVisibleAt(int index) const {
return model_->IsVisibleAt(index);
}
bool Menu::WorksWhenHiddenAt(int index) const {
return model_->WorksWhenHiddenAt(index);
}
void Menu::OnMenuWillClose() {
keep_alive_.Clear();
Emit("menu-will-close");
@@ -305,9 +325,15 @@ void Menu::FillObjectTemplate(v8::Isolate* isolate,
.SetMethod("setRole", &Menu::SetRole)
.SetMethod("setCustomType", &Menu::SetCustomType)
.SetMethod("clear", &Menu::Clear)
.SetMethod("getIndexOfCommandId", &Menu::GetIndexOfCommandId)
.SetMethod("getItemCount", &Menu::GetItemCount)
.SetMethod("getCommandIdAt", &Menu::GetCommandIdAt)
.SetMethod("getLabelAt", &Menu::GetLabelAt)
.SetMethod("getSublabelAt", &Menu::GetSublabelAt)
.SetMethod("getToolTipAt", &Menu::GetToolTipAt)
.SetMethod("isItemCheckedAt", &Menu::IsItemCheckedAt)
.SetMethod("isEnabledAt", &Menu::IsEnabledAt)
.SetMethod("worksWhenHiddenAt", &Menu::WorksWhenHiddenAt)
.SetMethod("isVisibleAt", &Menu::IsVisibleAt)
.SetMethod("popupAt", &Menu::PopupAt)
.SetMethod("closePopupAt", &Menu::ClosePopupAt)

View File

@@ -131,9 +131,14 @@ class Menu : public gin::Wrappable<Menu>,
void Clear();
int GetIndexOfCommandId(int command_id) const;
int GetItemCount() const;
int GetCommandIdAt(int index) const;
std::u16string GetLabelAt(int index) const;
std::u16string GetSublabelAt(int index) const;
std::u16string GetToolTipAt(int index) const;
bool IsItemCheckedAt(int index) const;
bool IsEnabledAt(int index) const;
bool IsVisibleAt(int index) const;
bool WorksWhenHiddenAt(int index) const;
gin_helper::SelfKeepAlive<Menu> keep_alive_{this};
};

View File

@@ -80,12 +80,9 @@ gin::DeprecatedWrapperInfo ServiceWorkerContext::kWrapperInfo = {
ServiceWorkerContext::ServiceWorkerContext(
v8::Isolate* isolate,
ElectronBrowserContext* browser_context)
: service_worker_context_{browser_context->GetDefaultStoragePartition()
->GetServiceWorkerContext()},
browser_context_id_{browser_context->UniqueId()},
storage_partition_config_{
browser_context->GetDefaultStoragePartition()->GetConfig()} {
ElectronBrowserContext* browser_context) {
storage_partition_ = browser_context->GetDefaultStoragePartition();
service_worker_context_ = storage_partition_->GetServiceWorkerContext();
service_worker_context_->AddObserver(this);
}
@@ -96,8 +93,9 @@ ServiceWorkerContext::~ServiceWorkerContext() {
void ServiceWorkerContext::OnRunningStatusChanged(
int64_t version_id,
blink::EmbeddedWorkerStatus running_status) {
if (auto* worker = ServiceWorkerMain::FromVersionID(
browser_context_id_, storage_partition_config_, version_id))
ServiceWorkerMain* worker =
ServiceWorkerMain::FromVersionID(version_id, storage_partition_);
if (worker)
worker->OnRunningStatusChanged(running_status);
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
@@ -135,8 +133,9 @@ void ServiceWorkerContext::OnRegistrationCompleted(const GURL& scope) {
void ServiceWorkerContext::OnVersionRedundant(int64_t version_id,
const GURL& scope) {
if (auto* worker = ServiceWorkerMain::FromVersionID(
browser_context_id_, storage_partition_config_, version_id))
ServiceWorkerMain* worker =
ServiceWorkerMain::FromVersionID(version_id, storage_partition_);
if (worker)
worker->OnVersionRedundant();
}
@@ -207,19 +206,18 @@ v8::Local<v8::Value> ServiceWorkerContext::GetWorkerFromVersionID(
v8::Isolate* isolate,
int64_t version_id) {
return ServiceWorkerMain::From(isolate, service_worker_context_,
browser_context_id_, storage_partition_config_,
version_id)
storage_partition_, version_id)
.ToV8();
}
gin_helper::Handle<ServiceWorkerMain>
ServiceWorkerContext::GetWorkerFromVersionIDIfExists(v8::Isolate* isolate,
int64_t version_id) {
if (auto* worker = ServiceWorkerMain::FromVersionID(
browser_context_id_, storage_partition_config_, version_id))
return gin_helper::CreateHandle(isolate, worker);
return {};
ServiceWorkerMain* worker =
ServiceWorkerMain::FromVersionID(version_id, storage_partition_);
if (!worker)
return gin_helper::Handle<ServiceWorkerMain>();
return gin_helper::CreateHandle(isolate, worker);
}
v8::Local<v8::Promise> ServiceWorkerContext::StartWorkerForScope(

View File

@@ -5,17 +5,18 @@
#ifndef ELECTRON_SHELL_BROWSER_API_ELECTRON_API_SERVICE_WORKER_CONTEXT_H_
#define ELECTRON_SHELL_BROWSER_API_ELECTRON_API_SERVICE_WORKER_CONTEXT_H_
#include <string>
#include "base/memory/raw_ptr.h"
#include "content/public/browser/service_worker_context.h"
#include "content/public/browser/service_worker_context_observer.h"
#include "content/public/browser/storage_partition_config.h"
#include "shell/browser/event_emitter_mixin.h"
#include "shell/common/gin_helper/wrappable.h"
#include "third_party/blink/public/common/service_worker/embedded_worker_status.h"
#include "third_party/blink/public/common/tokens/tokens.h"
namespace content {
class StoragePartition;
}
namespace gin_helper {
template <typename T>
class Handle;
@@ -98,13 +99,9 @@ class ServiceWorkerContext final
raw_ptr<content::ServiceWorkerContext> service_worker_context_;
// A key identifying the owning BrowserContext.
// Used in ServiceWorkerMain lookups.
const std::string browser_context_id_;
// A key identifying a StoragePartition within a BrowserContext.
// Used in ServiceWorkerMain lookups.
const content::StoragePartitionConfig storage_partition_config_;
// Service worker registration and versions are unique to a storage partition.
// Keep a reference to the storage partition to be used for lookups.
raw_ptr<content::StoragePartition> storage_partition_;
base::WeakPtrFactory<ServiceWorkerContext> weak_ptr_factory_{this};
};

View File

@@ -7,8 +7,6 @@
#include <string>
#include <utility>
#include "base/containers/flat_map.h"
#include "base/containers/map_util.h"
#include "base/logging.h"
#include "base/no_destructor.h"
#include "content/browser/service_worker/service_worker_context_wrapper.h" // nogncheck
@@ -29,6 +27,7 @@
#include "shell/common/gin_helper/promise.h"
#include "shell/common/node_includes.h"
#include "shell/common/v8_util.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
namespace {
@@ -59,23 +58,27 @@ std::optional<content::ServiceWorkerVersionBaseInfo> GetLiveVersionInfo(
namespace electron::api {
// ServiceWorkerKey -> ServiceWorkerMain*
auto& GetVersionIdMap() {
using Map = base::flat_map<ServiceWorkerKey, ServiceWorkerMain*>;
static base::NoDestructor<Map> instance;
using VersionIdMap = absl::flat_hash_map<ServiceWorkerKey,
ServiceWorkerMain*,
ServiceWorkerKey::Hasher>;
VersionIdMap& GetVersionIdMap() {
static base::NoDestructor<VersionIdMap> instance;
return *instance;
}
ServiceWorkerMain* FromServiceWorkerKey(const ServiceWorkerKey& key) {
return base::FindPtrOrNull(GetVersionIdMap(), key);
VersionIdMap& version_map = GetVersionIdMap();
auto iter = version_map.find(key);
auto* service_worker = iter == version_map.end() ? nullptr : iter->second;
return service_worker;
}
// static
ServiceWorkerMain* ServiceWorkerMain::FromVersionID(
std::string browser_context_id,
content::StoragePartitionConfig storage_partition_config,
int64_t version_id) {
const ServiceWorkerKey key{std::move(browser_context_id),
std::move(storage_partition_config), version_id};
int64_t version_id,
const content::StoragePartition* storage_partition) {
ServiceWorkerKey key(version_id, storage_partition);
return FromServiceWorkerKey(key);
}
@@ -84,10 +87,8 @@ gin::DeprecatedWrapperInfo ServiceWorkerMain::kWrapperInfo = {
ServiceWorkerMain::ServiceWorkerMain(content::ServiceWorkerContext* sw_context,
int64_t version_id,
ServiceWorkerKey key)
: version_id_{version_id},
key_{std::move(key)},
service_worker_context_{sw_context} {
const ServiceWorkerKey& key)
: version_id_(version_id), key_(key), service_worker_context_(sw_context) {
GetVersionIdMap().emplace(key_, this);
InvalidateVersionInfo();
}
@@ -297,14 +298,12 @@ gin_helper::Handle<ServiceWorkerMain> ServiceWorkerMain::New(
gin_helper::Handle<ServiceWorkerMain> ServiceWorkerMain::From(
v8::Isolate* isolate,
content::ServiceWorkerContext* sw_context,
std::string browser_context_id,
content::StoragePartitionConfig storage_partition_config,
const content::StoragePartition* storage_partition,
int64_t version_id) {
ServiceWorkerKey service_worker_key{std::move(browser_context_id),
std::move(storage_partition_config),
version_id};
ServiceWorkerKey service_worker_key(version_id, storage_partition);
if (auto* service_worker = FromServiceWorkerKey(service_worker_key))
auto* service_worker = FromServiceWorkerKey(service_worker_key);
if (service_worker)
return gin_helper::CreateHandle(isolate, service_worker);
// Ensure ServiceWorkerVersion exists and is not redundant (pending deletion)
@@ -314,8 +313,8 @@ gin_helper::Handle<ServiceWorkerMain> ServiceWorkerMain::From(
}
auto handle = gin_helper::CreateHandle(
isolate, new ServiceWorkerMain{sw_context, version_id,
std::move(service_worker_key)});
isolate,
new ServiceWorkerMain(sw_context, version_id, service_worker_key));
// Prevent garbage collection of worker until it has been deleted internally.
handle->Pin(isolate);

View File

@@ -5,13 +5,13 @@
#ifndef ELECTRON_SHELL_BROWSER_API_ELECTRON_API_SERVICE_WORKER_MAIN_H_
#define ELECTRON_SHELL_BROWSER_API_ELECTRON_API_SERVICE_WORKER_MAIN_H_
#include <compare>
#include <string>
#include "base/memory/raw_ptr.h"
#include "base/process/process.h"
#include "content/public/browser/global_routing_id.h"
#include "content/public/browser/service_worker_context.h"
#include "content/public/browser/service_worker_version_base_info.h"
#include "content/public/browser/storage_partition_config.h"
#include "mojo/public/cpp/bindings/associated_receiver.h"
#include "mojo/public/cpp/bindings/associated_remote.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
@@ -24,6 +24,10 @@
class GURL;
namespace content {
class StoragePartition;
}
namespace gin {
class Arguments;
} // namespace gin
@@ -38,21 +42,41 @@ class Promise;
namespace electron::api {
// Key to uniquely identify a ServiceWorkerMain by its
// BrowserContext ID, the StoragePartition key, and version id.
// Key to uniquely identify a ServiceWorkerMain by its Version ID within the
// associated StoragePartition.
struct ServiceWorkerKey {
std::string browser_context_id;
content::StoragePartitionConfig storage_partition_config;
int64_t version_id;
auto operator<=>(const ServiceWorkerKey&) const = default;
raw_ptr<const content::StoragePartition> storage_partition;
ServiceWorkerKey(int64_t id, const content::StoragePartition* partition)
: version_id(id), storage_partition(partition) {}
bool operator<(const ServiceWorkerKey& other) const {
return std::tie(version_id, storage_partition) <
std::tie(other.version_id, other.storage_partition);
}
bool operator==(const ServiceWorkerKey& other) const {
return version_id == other.version_id &&
storage_partition == other.storage_partition;
}
struct Hasher {
std::size_t operator()(const ServiceWorkerKey& key) const {
return std::hash<const content::StoragePartition*>()(
key.storage_partition) ^
std::hash<int64_t>()(key.version_id);
}
};
};
// Creates a wrapper to align with the lifecycle of the non-public
// content::ServiceWorkerVersion. Object instances are pinned for the lifetime
// of the underlying SW such that registered IPC handlers continue to dispatch.
//
// Instances are uniquely identified by pairing their version ID with the
// BrowserContext and StoragePartition in which they're registered.
// Instances are uniquely identified by pairing their version ID and the
// StoragePartition in which they're registered. In Electron, this is always
// the default StoragePartition for the associated BrowserContext.
class ServiceWorkerMain final
: public gin_helper::DeprecatedWrappable<ServiceWorkerMain>,
public gin_helper::Pinnable<ServiceWorkerMain>,
@@ -64,13 +88,11 @@ class ServiceWorkerMain final
static gin_helper::Handle<ServiceWorkerMain> From(
v8::Isolate* isolate,
content::ServiceWorkerContext* sw_context,
std::string browser_context_id,
content::StoragePartitionConfig storage_partition_config,
const content::StoragePartition* storage_partition,
int64_t version_id);
static ServiceWorkerMain* FromVersionID(
std::string browser_context_id,
content::StoragePartitionConfig storage_partition_config,
int64_t version_id);
int64_t version_id,
const content::StoragePartition* storage_partition);
// gin_helper::Constructible
static void FillObjectTemplate(v8::Isolate*, v8::Local<v8::ObjectTemplate>);
@@ -90,7 +112,7 @@ class ServiceWorkerMain final
protected:
explicit ServiceWorkerMain(content::ServiceWorkerContext* sw_context,
int64_t version_id,
ServiceWorkerKey key);
const ServiceWorkerKey& key);
~ServiceWorkerMain() override;
private:
@@ -124,12 +146,11 @@ class ServiceWorkerMain final
GURL ScopeURL() const;
GURL ScriptURL() const;
// Version ID assigned by the service worker storage.
const int64_t version_id_;
// Version ID unique only to the StoragePartition.
int64_t version_id_;
// Unique identifier pairing the Version ID, BrowserContext, and
// StoragePartition.
const ServiceWorkerKey key_;
// Unique identifier pairing the Version ID and StoragePartition.
ServiceWorkerKey key_;
// Whether the Service Worker version has been destroyed.
bool version_destroyed_ = false;
@@ -142,6 +163,8 @@ class ServiceWorkerMain final
raw_ptr<content::ServiceWorkerContext> service_worker_context_;
mojo::AssociatedRemote<mojom::ElectronRenderer> remote_;
std::unique_ptr<gin_helper::Promise<void>> start_worker_promise_;
};
} // namespace electron::api

View File

@@ -1555,6 +1555,26 @@ v8::Local<v8::Value> Session::ClearData(gin::Arguments* const args) {
return promise_handle;
}
void Session::RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
v8::Local<v8::Value> val) {
auto* isolate = JavascriptEnvironment::GetIsolate();
gin_helper::Handle<UtilityProcessWrapper> handler;
if (!(val->IsNull() || gin::ConvertFromV8(isolate, val, &handler))) {
thrower.ThrowTypeError("Must pass null or UtilityProcess");
return;
}
auto* prefs = SessionPreferences::FromBrowserContext(browser_context());
DCHECK(prefs);
if (!handler.IsEmpty()) {
prefs->SetLocalAIHandler(handler->GetWeakPtr());
} else {
prefs->SetLocalAIHandler(nullptr);
}
}
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
base::Value Session::GetSpellCheckerLanguages() {
return browser_context_->prefs()
@@ -1841,6 +1861,7 @@ void Session::FillObjectTemplate(v8::Isolate* isolate,
.SetMethod("setCodeCachePath", &Session::SetCodeCachePath)
.SetMethod("clearCodeCaches", &Session::ClearCodeCaches)
.SetMethod("clearData", &Session::ClearData)
.SetMethod("_registerLocalAIHandler", &Session::RegisterLocalAIHandler)
.SetProperty("cookies", &Session::Cookies)
.SetProperty("extensions", &Session::Extensions)
.SetProperty("netLog", &Session::NetLog)

View File

@@ -19,6 +19,7 @@
#include "gin/wrappable.h"
#include "services/network/public/mojom/host_resolver.mojom-forward.h"
#include "services/network/public/mojom/ssl_config.mojom-forward.h"
#include "shell/browser/api/electron_api_utility_process.h"
#include "shell/browser/api/ipc_dispatcher.h"
#include "shell/browser/event_emitter_mixin.h"
#include "shell/browser/net/resolve_proxy_helper.h"
@@ -178,6 +179,8 @@ class Session final : public gin::Wrappable<Session>,
void SetCodeCachePath(gin::Arguments* args);
v8::Local<v8::Promise> ClearCodeCaches(const gin_helper::Dictionary& options);
v8::Local<v8::Value> ClearData(gin::Arguments* args);
void RegisterLocalAIHandler(gin_helper::ErrorThrower thrower,
v8::Local<v8::Value> val);
#if BUILDFLAG(ENABLE_BUILTIN_SPELLCHECKER)
base::Value GetSpellCheckerLanguages();
void SetSpellCheckerLanguages(gin_helper::ErrorThrower thrower,

View File

@@ -47,6 +47,10 @@
#include "base/win/windows_types.h"
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace electron {
namespace {
@@ -454,6 +458,19 @@ UtilityProcessWrapper::CreateURLLoaderFactoryParams() {
return params;
}
#if BUILDFLAG(ENABLE_PROMPT_API)
void UtilityProcessWrapper::BindAIManager(
std::optional<int32_t> web_contents_id,
const url::Origin& security_origin,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
auto params = node::mojom::BindAIManagerParams::New();
params->web_contents_id = web_contents_id;
params->security_origin = security_origin;
node_service_remote_->BindAIManager(std::move(params), std::move(ai_manager));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
// static
raw_ptr<UtilityProcessWrapper> UtilityProcessWrapper::FromProcessId(
base::ProcessId pid) {

View File

@@ -15,6 +15,7 @@
#include "base/memory/weak_ptr.h"
#include "base/process/process_handle.h"
#include "content/public/browser/service_process_host.h"
#include "electron/buildflags/buildflags.h"
#include "mojo/public/cpp/bindings/message.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "shell/browser/event_emitter_mixin.h"
@@ -24,6 +25,10 @@
#include "shell/services/node/public/mojom/node_service.mojom.h"
#include "v8/include/v8-forward.h"
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace gin {
class Arguments;
} // namespace gin
@@ -58,6 +63,16 @@ class UtilityProcessWrapper final
static gin_helper::Handle<UtilityProcessWrapper> Create(gin::Arguments* args);
static raw_ptr<UtilityProcessWrapper> FromProcessId(base::ProcessId pid);
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(std::optional<int32_t> web_contents_id,
const url::Origin& security_origin,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager);
#endif // BUILDFLAG(ENABLE_PROMPT_API)
base::WeakPtr<UtilityProcessWrapper> GetWeakPtr() {
return weak_factory_.GetWeakPtr();
}
void Shutdown(uint32_t exit_code);
// gin_helper::Wrappable

View File

@@ -1767,8 +1767,7 @@ bool WebContents::CheckMediaAccessPermission(
content::WebContents::FromRenderFrameHost(render_frame_host);
auto* permission_helper =
WebContentsPermissionHelper::FromWebContents(web_contents);
return permission_helper->CheckMediaAccessPermission(render_frame_host,
security_origin, type);
return permission_helper->CheckMediaAccessPermission(security_origin, type);
}
void WebContents::RequestMediaAccessPermission(
@@ -2451,9 +2450,16 @@ int32_t WebContents::GetProcessID() const {
}
base::ProcessId WebContents::GetOSProcessID() const {
const auto& process =
web_contents()->GetPrimaryMainFrame()->GetProcess()->GetProcess();
return process.IsValid() ? process.Pid() : base::kNullProcessId;
base::ProcessHandle process_handle = web_contents()
->GetPrimaryMainFrame()
->GetProcess()
->GetProcess()
.Handle();
return base::GetProcId(process_handle);
}
bool WebContents::Equal(const WebContents* web_contents) const {
return ID() == web_contents->ID();
}
GURL WebContents::GetURL() const {
@@ -4582,6 +4588,7 @@ void WebContents::FillObjectTemplate(v8::Isolate* isolate,
&WebContents::SetBackgroundThrottling)
.SetMethod("getProcessId", &WebContents::GetProcessID)
.SetMethod("getOSProcessId", &WebContents::GetOSProcessID)
.SetMethod("equal", &WebContents::Equal)
.SetMethod("_loadURL", &WebContents::LoadURL)
.SetMethod("reload", &WebContents::Reload)
.SetMethod("reloadIgnoringCache", &WebContents::ReloadIgnoringCache)

View File

@@ -195,6 +195,7 @@ class WebContents final : public ExclusiveAccessContext,
int32_t GetProcessID() const;
base::ProcessId GetOSProcessID() const;
[[nodiscard]] Type type() const { return type_; }
bool Equal(const WebContents* web_contents) const;
void LoadURL(const GURL& url, const gin_helper::Dictionary& options);
void Reload();
void ReloadIgnoringCache();

View File

@@ -417,13 +417,10 @@ std::string WebFrameMain::FrameToken() const {
base::ProcessId WebFrameMain::OSProcessID() const {
if (!CheckRenderFrame())
return base::kNullProcessId;
const auto& process = render_frame_host()->GetProcess()->GetProcess();
if (!process.IsValid())
return base::kNullProcessId;
return process.Pid();
return -1;
base::ProcessHandle process_handle =
render_frame_host()->GetProcess()->GetProcess().Handle();
return base::GetProcId(process_handle);
}
int32_t WebFrameMain::ProcessID() const {

View File

@@ -230,12 +230,20 @@
#include "ui/webui/resources/cr_components/help_bubble/help_bubble.mojom.h" // nogncheck
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "shell/browser/ai/proxying_ai_manager.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
using content::BrowserThread;
namespace electron {
namespace {
#if BUILDFLAG(ENABLE_PROMPT_API)
const char kAIManagerUserDataKey[] = "ai_manager";
#endif // BUILDFLAG(ENABLE_PROMPT_API)
ElectronBrowserClient* g_browser_client = nullptr;
base::NoDestructor<std::string> g_io_thread_application_locale;
@@ -1580,6 +1588,26 @@ void ElectronBrowserClient::
#endif
}
#if BUILDFLAG(ENABLE_PROMPT_API)
// Refs
// https://source.chromium.org/chromium/chromium/src/+/main:chrome/browser/chrome_content_browser_client.cc;l=8724-8737;drc=74754be9d4550a487df006a51a33318245d37301
void ElectronBrowserClient::BindAIManager(
content::BrowserContext* browser_context,
base::SupportsUserData* context_user_data,
content::RenderFrameHost* rfh,
mojo::PendingReceiver<blink::mojom::AIManager> receiver) {
if (!context_user_data->GetUserData(kAIManagerUserDataKey)) {
context_user_data->SetUserData(
kAIManagerUserDataKey,
std::make_unique<ProxyingAIManager>(browser_context, rfh));
}
ProxyingAIManager* ai_manager = static_cast<ProxyingAIManager*>(
context_user_data->GetUserData(kAIManagerUserDataKey));
ai_manager->AddReceiver(std::move(receiver));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
std::string ElectronBrowserClient::GetApplicationLocale() {
return BrowserThread::CurrentlyOn(BrowserThread::IO)
? *g_io_thread_application_locale

View File

@@ -274,6 +274,14 @@ class ElectronBrowserClient : public content::ContentBrowserClient,
const content::ServiceWorkerVersionBaseInfo& service_worker_version_info,
blink::AssociatedInterfaceRegistry& associated_registry) override;
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(
content::BrowserContext* browser_context,
base::SupportsUserData* context_user_data,
content::RenderFrameHost* rfh,
mojo::PendingReceiver<blink::mojom::AIManager> receiver) override;
#endif // BUILDFLAG(ENABLE_PROMPT_API)
bool HandleExternalProtocol(
const GURL& url,
content::WebContents::Getter web_contents_getter,

View File

@@ -20,7 +20,6 @@
#include "base/time/time.h"
#include "content/public/browser/browser_context.h"
#include "extensions/browser/extension_file_task_runner.h"
#include "extensions/browser/extension_pref_names.h"
#include "extensions/browser/extension_prefs.h"
#include "extensions/browser/extension_registry.h"
#include "extensions/browser/pref_names.h"
@@ -28,7 +27,6 @@
#include "extensions/common/error_utils.h"
#include "extensions/common/file_util.h"
#include "extensions/common/manifest_constants.h"
#include "extensions/common/manifest_handlers/background_info.h"
namespace extensions {
@@ -145,19 +143,6 @@ void ElectronExtensionLoader::FinishExtensionLoad(
std::pair<scoped_refptr<const Extension>, std::string> result) {
scoped_refptr<const Extension> extension = result.first;
if (extension) {
ExtensionPrefs* extension_prefs = ExtensionPrefs::Get(browser_context_);
if (BackgroundInfo::IsServiceWorkerBased(extension.get())) {
// Tell Chromium that it needs to start the extension's service worker.
// Chromium usually does this only when an extension is first installed
// because Chrome will restart the service worker when the browser
// relaunches. In Electron, we make a fresh install on every app start,
// so we need to run the fresh install logic again.
extension_prefs->UpdateExtensionPref(
extension.get()->id(), extensions::kPrefHasStartedServiceWorker,
base::Value(false));
}
extension_registrar_->AddExtension(extension);
// Write extension install time to ExtensionPrefs.
@@ -167,6 +152,7 @@ void ElectronExtensionLoader::FinishExtensionLoad(
// Implementation for writing the pref was based on
// PreferenceAPIBase::SetExtensionControlledPref.
{
ExtensionPrefs* extension_prefs = ExtensionPrefs::Get(browser_context_);
ExtensionPrefs::ScopedDictionaryUpdate update(
extension_prefs, extension.get()->id(),
extensions::pref_names::kPrefPreferences);

View File

@@ -697,11 +697,7 @@ void FileSystemAccessPermissionContext::ConfirmSensitiveEntryAccess(
content::GlobalRenderFrameHostId frame_id,
base::OnceCallback<void(SensitiveEntryResult)> callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto [it, inserted] = callback_map_.try_emplace(path_info.path);
it->second.push_back(std::move(callback));
if (!inserted)
return;
callback_map_.try_emplace(path_info.path, std::move(callback));
auto after_blocklist_check_callback = base::BindOnce(
&FileSystemAccessPermissionContext::DidCheckPathAgainstBlocklist,
@@ -773,11 +769,8 @@ void FileSystemAccessPermissionContext::PerformAfterWriteChecks(
void FileSystemAccessPermissionContext::RunRestrictedPathCallback(
const base::FilePath& file_path,
SensitiveEntryResult result) {
if (auto val = callback_map_.extract(file_path)) {
for (auto& callback : val.mapped()) {
std::move(callback).Run(result);
}
}
if (auto val = callback_map_.extract(file_path))
std::move(val.mapped()).Run(result);
}
void FileSystemAccessPermissionContext::OnRestrictedPathResult(

View File

@@ -196,8 +196,7 @@ class FileSystemAccessPermissionContext
std::map<url::Origin, base::DictValue> id_pathinfo_map_;
std::map<base::FilePath,
std::vector<base::OnceCallback<void(SensitiveEntryResult)>>>
std::map<base::FilePath, base::OnceCallback<void(SensitiveEntryResult)>>
callback_map_;
std::unique_ptr<ChromeFileSystemAccessPermissionContext::BlockPathRules>

View File

@@ -86,7 +86,6 @@ JavascriptEnvironment::~JavascriptEnvironment() {
// Otherwise cppgc::internal::Sweeper::Start will try to request a task runner
// from the NodePlatform with an already unregistered isolate.
locker_.reset();
DCHECK(!microtasks_runner_);
isolate_holder_.reset();
platform_->UnregisterIsolate(isolate_);
@@ -160,7 +159,6 @@ void JavascriptEnvironment::DestroyMicrotasksRunner() {
gin_helper::CleanedUpAtExit::DoCleanup();
}
base::CurrentThread::Get()->RemoveTaskObserver(microtasks_runner_.get());
microtasks_runner_.reset();
}
} // namespace electron

View File

@@ -4,16 +4,12 @@
#include "shell/browser/notifications/linux/libnotify_notification.h"
#include <dlfcn.h>
#include <array>
#include <string>
#include "base/containers/flat_set.h"
#include "base/files/file_enumerator.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/nix/xdg_util.h"
#include "base/no_destructor.h"
#include "base/process/process_handle.h"
#include "base/strings/utf_string_conversions.h"
@@ -54,9 +50,6 @@ bool NotifierSupportsActions() {
return HasCapability("actions");
}
using GetActivationTokenFunc = const char* (*)(NotifyNotification*);
GetActivationTokenFunc g_get_activation_token = nullptr;
void log_and_clear_error(GError* error, const char* context) {
LOG(ERROR) << context << ": domain=" << error->domain
<< " code=" << error->code << " message=\"" << error->message
@@ -68,40 +61,18 @@ void log_and_clear_error(GError* error, const char* context) {
// static
bool LibnotifyNotification::Initialize() {
constexpr std::array kLibnotifySonames = {
"libnotify.so.4",
"libnotify.so.5",
"libnotify.so.1",
"libnotify.so",
};
const char* loaded_soname = nullptr;
for (const char* soname : kLibnotifySonames) {
if (GetLibNotifyLoader().Load(soname)) {
loaded_soname = soname;
break;
}
}
if (!loaded_soname) {
if (!GetLibNotifyLoader().Load("libnotify.so.4") && // most common one
!GetLibNotifyLoader().Load("libnotify.so.5") &&
!GetLibNotifyLoader().Load("libnotify.so.1") &&
!GetLibNotifyLoader().Load("libnotify.so")) {
LOG(WARNING) << "Unable to find libnotify; notifications disabled";
return false;
}
if (!GetLibNotifyLoader().notify_is_initted() &&
!GetLibNotifyLoader().notify_init(GetApplicationName().c_str())) {
LOG(WARNING) << "Unable to initialize libnotify; notifications disabled";
return false;
}
// Safe to cache the symbol after dlclose(handle) because libnotify remains
// loaded via GetLibNotifyLoader() for the process lifetime.
if (void* handle = dlopen(loaded_soname, RTLD_LAZY)) {
g_get_activation_token = reinterpret_cast<GetActivationTokenFunc>(
dlsym(handle, "notify_notification_get_activation_token"));
dlclose(handle);
}
return true;
}
@@ -221,14 +192,6 @@ void LibnotifyNotification::OnNotificationView(NotifyNotification* notification,
gpointer user_data) {
LibnotifyNotification* that = static_cast<LibnotifyNotification*>(user_data);
DCHECK(that);
if (g_get_activation_token) {
const char* token = g_get_activation_token(notification);
if (token && *token) {
base::nix::SetActivationToken(std::string(token));
}
}
that->NotificationClicked();
}

View File

@@ -51,7 +51,8 @@ bool ElectronSerialDelegate::CanRequestPortPermission(
auto* web_contents = content::WebContents::FromRenderFrameHost(frame);
auto* permission_helper =
WebContentsPermissionHelper::FromWebContents(web_contents);
return permission_helper->CheckSerialAccessPermission(frame);
return permission_helper->CheckSerialAccessPermission(
frame->GetLastCommittedOrigin());
}
bool ElectronSerialDelegate::HasPortPermission(

View File

@@ -30,6 +30,11 @@ SessionPreferences* SessionPreferences::FromBrowserContext(
return static_cast<SessionPreferences*>(context->GetUserData(&kLocatorKey));
}
base::CallbackListSubscription SessionPreferences::AddAIHandlerChangedCallback(
base::RepeatingClosure callback) {
return ai_handler_changed_callbacks_.Add(std::move(callback));
}
bool SessionPreferences::HasServiceWorkerPreloadScript() {
const auto& preloads = preload_scripts();
auto it = std::find_if(

View File

@@ -7,8 +7,11 @@
#include <vector>
#include "base/callback_list.h"
#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "base/supports_user_data.h"
#include "shell/browser/api/electron_api_utility_process.h"
#include "shell/browser/preload_script.h"
namespace content {
@@ -17,6 +20,10 @@ class BrowserContext;
namespace electron {
namespace api {
class UtilityProcessWrapper;
}
class SessionPreferences : public base::SupportsUserData::Data {
public:
static SessionPreferences* FromBrowserContext(
@@ -30,6 +37,18 @@ class SessionPreferences : public base::SupportsUserData::Data {
bool HasServiceWorkerPreloadScript();
const base::WeakPtr<api::UtilityProcessWrapper>& GetLocalAIHandler() const {
return local_ai_handler_;
}
void SetLocalAIHandler(base::WeakPtr<api::UtilityProcessWrapper> handler) {
local_ai_handler_ = handler;
ai_handler_changed_callbacks_.Notify();
}
base::CallbackListSubscription AddAIHandlerChangedCallback(
base::RepeatingClosure callback);
private:
SessionPreferences();
@@ -37,6 +56,8 @@ class SessionPreferences : public base::SupportsUserData::Data {
static int kLocatorKey;
std::vector<PreloadScript> preload_scripts_;
base::WeakPtr<api::UtilityProcessWrapper> local_ai_handler_;
base::RepeatingClosureList ai_handler_changed_callbacks_;
};
} // namespace electron

View File

@@ -12,7 +12,6 @@
#include <utility>
#include "base/base64.h"
#include "base/containers/fixed_flat_set.h"
#include "base/containers/span.h"
#include "base/dcheck_is_on.h"
#include "base/memory/raw_ptr.h"
@@ -161,13 +160,6 @@ void OnOpenItemComplete(const base::FilePath& path, const std::string& result) {
constexpr base::TimeDelta kInitialBackoffDelay = base::Milliseconds(250);
constexpr base::TimeDelta kMaxBackoffDelay = base::Seconds(10);
constexpr auto kValidDockStates = base::MakeFixedFlatSet<std::string_view>(
{"bottom", "left", "right", "undocked"});
bool IsValidDockState(const std::string& state) {
return kValidDockStates.contains(state);
}
} // namespace
class InspectableWebContents::NetworkResourceLoader
@@ -402,7 +394,7 @@ void InspectableWebContents::SetDockState(const std::string& state) {
can_dock_ = false;
} else {
can_dock_ = true;
dock_state_ = IsValidDockState(state) ? state : "right";
dock_state_ = state;
}
}
@@ -567,13 +559,7 @@ void InspectableWebContents::LoadCompleted() {
pref_service_->GetDict(kDevToolsPreferences);
const std::string* current_dock_state =
prefs.FindString("currentDockState");
if (current_dock_state) {
std::string sanitized;
base::RemoveChars(*current_dock_state, "\"", &sanitized);
dock_state_ = IsValidDockState(sanitized) ? sanitized : "right";
} else {
dock_state_ = "right";
}
base::RemoveChars(*current_dock_state, "\"", &dock_state_);
}
#if BUILDFLAG(IS_WIN) || BUILDFLAG(IS_LINUX)
auto* api_web_contents = api::WebContents::From(GetWebContents());

View File

@@ -111,7 +111,7 @@ void AutofillPopupView::Show() {
auto* host = popup_->frame_host_->GetRenderViewHost()->GetWidget();
host->AddKeyPressEventCallback(keypress_callback_);
GetViewAccessibility().NotifyEvent(ax::mojom::Event::kMenuStart, true);
NotifyAccessibilityEventDeprecated(ax::mojom::Event::kMenuStart, true);
}
void AutofillPopupView::Hide() {
@@ -122,7 +122,7 @@ void AutofillPopupView::Hide() {
}
RemoveObserver();
GetViewAccessibility().NotifyEvent(ax::mojom::Event::kMenuEnd, true);
NotifyAccessibilityEventDeprecated(ax::mojom::Event::kMenuEnd, true);
if (GetWidget()) {
GetWidget()->Close();
@@ -165,7 +165,7 @@ void AutofillPopupView::OnSelectedRowChanged(
int selected = current_row_selection.value_or(-1);
if (selected == -1 || static_cast<size_t>(selected) >= children().size())
return;
children().at(selected)->GetViewAccessibility().NotifyEvent(
children().at(selected)->NotifyAccessibilityEventDeprecated(
ax::mojom::Event::kSelection, true);
}
}

View File

@@ -163,7 +163,7 @@ int ClientFrameViewLinux::ResizingBorderHitTest(const gfx::Point& point) {
gfx::Rect ClientFrameViewLinux::GetBoundsForClientView() const {
gfx::Rect client_bounds = bounds();
if (!frame_->IsFullscreen()) {
client_bounds.Inset(linux_frame_layout_->FrameBorderInsets(false));
client_bounds.Inset(RestoredFrameBorderInsets());
client_bounds.Inset(
gfx::Insets::TLBR(GetTitlebarBounds().height(), 0, 0, 0));
}
@@ -236,21 +236,6 @@ void ClientFrameViewLinux::Layout(PassKey) {
}
void ClientFrameViewLinux::OnPaint(gfx::Canvas* canvas) {
if (frame_->IsFullscreen()) {
return;
}
if (frame_->IsMaximized()) {
// Some GTK themes (Breeze) still render shadow/border assets when
// maximized, and we don't need a border when maximized anyway. Chromium
// switches on this too: OpaqueBrowserFrameView::PaintMaximizedFrameBorder.
PaintMaximizedFrameBorder(canvas);
} else {
PaintRestoredFrameBorder(canvas);
}
}
void ClientFrameViewLinux::PaintRestoredFrameBorder(gfx::Canvas* canvas) {
if (auto* frame_provider = linux_frame_layout_->GetFrameProvider()) {
frame_provider->PaintWindowFrame(
canvas, GetLocalBounds(), GetTitlebarBounds().bottom(),
@@ -258,18 +243,6 @@ void ClientFrameViewLinux::PaintRestoredFrameBorder(gfx::Canvas* canvas) {
}
}
void ClientFrameViewLinux::PaintMaximizedFrameBorder(gfx::Canvas* canvas) {
ui::NativeTheme::FrameTopAreaExtraParams frame_top_area;
frame_top_area.use_custom_frame = true;
frame_top_area.is_active = ShouldPaintAsActive();
frame_top_area.default_background_color = SK_ColorTRANSPARENT;
ui::NativeTheme::ExtraParams params(frame_top_area);
GetNativeTheme()->Paint(
canvas->sk_canvas(), GetColorProvider(), ui::NativeTheme::kFrameTopArea,
ui::NativeTheme::kNormal,
gfx::Rect(0, 0, width(), GetTitlebarBounds().bottom()), params);
}
void ClientFrameViewLinux::PaintAsActiveChanged() {
UpdateThemeValues();
}
@@ -278,15 +251,23 @@ void ClientFrameViewLinux::UpdateThemeValues() {
gtk::GtkCssContext window_context =
gtk::AppendCssNodeToStyleContext({}, "window.background.csd");
gtk::GtkCssContext headerbar_context = gtk::AppendCssNodeToStyleContext(
window_context, "headerbar.default-decoration.titlebar");
{}, "headerbar.default-decoration.titlebar");
gtk::GtkCssContext title_context =
gtk::AppendCssNodeToStyleContext(headerbar_context, "label.title");
gtk::GtkCssContext button_context = gtk::AppendCssNodeToStyleContext(
headerbar_context, "button.image-button");
gtk_style_context_set_parent(headerbar_context, window_context);
gtk_style_context_set_parent(title_context, headerbar_context);
gtk_style_context_set_parent(button_context, headerbar_context);
// ShouldPaintAsActive asks the widget, so assume active if the widget is not
// set yet.
if (GetWidget() != nullptr && !ShouldPaintAsActive()) {
gtk_style_context_set_state(window_context, GTK_STATE_FLAG_BACKDROP);
gtk_style_context_set_state(headerbar_context, GTK_STATE_FLAG_BACKDROP);
gtk_style_context_set_state(title_context, GTK_STATE_FLAG_BACKDROP);
gtk_style_context_set_state(button_context, GTK_STATE_FLAG_BACKDROP);
}
theme_values_.window_border_radius =
@@ -300,6 +281,10 @@ void ClientFrameViewLinux::UpdateThemeValues() {
theme_values_.title_color = gtk::GtkStyleContextGetColor(title_context);
theme_values_.title_padding = gtk::GtkStyleContextGetPadding(title_context);
gtk::GtkStyleContextGet(button_context, "min-height",
&theme_values_.button_min_size, nullptr);
theme_values_.button_padding = gtk::GtkStyleContextGetPadding(button_context);
title_->SetEnabledColor(theme_values_.title_color);
InvalidateLayout();
@@ -314,9 +299,8 @@ ClientFrameViewLinux::GetButtonTypeToSkip() const {
}
void ClientFrameViewLinux::UpdateButtonImages() {
int top_area_height = theme_values_.titlebar_min_height +
theme_values_.titlebar_padding.height();
nav_button_provider_->RedrawImages(top_area_height, frame_->IsMaximized(),
nav_button_provider_->RedrawImages(theme_values_.button_min_size,
frame_->IsMaximized(),
ShouldPaintAsActive());
ui::NavButtonProvider::FrameButtonDisplayType skip_type =
@@ -384,14 +368,7 @@ void ClientFrameViewLinux::LayoutButtonsOnSide(
button->button->SetVisible(true);
// CSS min-size/height/width is not enough to determine the actual size of
// the buttons, so we sample the rendered image. See Chromium's
// BrowserFrameViewLinuxNative::MaybeUpdateCachedFrameButtonImages.
int button_width =
nav_button_provider_
->GetImage(button->type,
ui::NavButtonProvider::ButtonState::kNormal)
.width();
int button_width = theme_values_.button_min_size;
int next_button_offset =
button_width + nav_button_provider_->GetInterNavButtonSpacing();
@@ -427,7 +404,7 @@ gfx::Rect ClientFrameViewLinux::GetTitlebarBounds() const {
std::max(font_height, theme_values_.titlebar_min_height) +
GetTitlebarContentInsets().height();
gfx::Insets decoration_insets = linux_frame_layout_->FrameBorderInsets(false);
gfx::Insets decoration_insets = RestoredFrameBorderInsets();
// We add the inset height here, so the .Inset() that follows won't reduce it
// to be too small.

View File

@@ -91,11 +91,12 @@ class ClientFrameViewLinux : public FramelessView,
SkColor title_color;
gfx::Insets title_padding;
int button_min_size;
gfx::Insets button_padding;
};
void PaintAsActiveChanged();
void PaintRestoredFrameBorder(gfx::Canvas* canvas);
void PaintMaximizedFrameBorder(gfx::Canvas* canvas);
void UpdateThemeValues();

View File

@@ -83,12 +83,6 @@ gfx::Insets LinuxFrameLayout::RestoredFrameBorderInsets() const {
return gfx::Insets();
}
gfx::Insets LinuxFrameLayout::FrameBorderInsets(bool restored) const {
return !restored && (window_->IsMaximized() || window_->IsFullscreen())
? gfx::Insets()
: RestoredFrameBorderInsets();
}
gfx::Insets LinuxFrameLayout::GetInputInsets() const {
return gfx::Insets(kResizeInsideBoundsSize);
}
@@ -112,7 +106,7 @@ void LinuxFrameLayout::set_tiled(bool tiled) {
gfx::Rect LinuxFrameLayout::GetWindowBounds() const {
gfx::Rect bounds = window_->widget()->GetWindowBoundsInScreen();
bounds.Inset(FrameBorderInsets(false));
bounds.Inset(RestoredFrameBorderInsets());
return bounds;
}

View File

@@ -44,9 +44,6 @@ class LinuxFrameLayout {
CSDStyle csd_style);
// Insets from the transparent widget border to the opaque part of the window.
// Returns empty insets when maximized or fullscreen unless |restored| is
// true. Matches Chromium's OpaqueBrowserFrameViewLayout::FrameBorderInsets.
gfx::Insets FrameBorderInsets(bool restored) const;
virtual gfx::Insets RestoredFrameBorderInsets() const;
// Insets for parts of the surface that should be counted for user input.
virtual gfx::Insets GetInputInsets() const;

View File

@@ -210,7 +210,7 @@ void OpaqueFrameView::OnPaint(gfx::Canvas* canvas) {
return;
const bool active = ShouldPaintAsActive();
const gfx::Insets border = FrameBorderInsets(false);
const gfx::Insets border = RestoredFrameBorderInsets();
const bool showing_shadow = linux_frame_layout_->IsShowingShadow();
gfx::RectF bounds_dip(GetLocalBounds());
if (showing_shadow) {
@@ -342,7 +342,9 @@ views::Button* OpaqueFrameView::CreateButton(
}
gfx::Insets OpaqueFrameView::FrameBorderInsets(bool restored) const {
return linux_frame_layout_->FrameBorderInsets(restored);
return !restored && IsFrameCondensed()
? gfx::Insets()
: linux_frame_layout_->RestoredFrameBorderInsets();
}
int OpaqueFrameView::FrameTopBorderThickness(bool restored) const {

View File

@@ -101,13 +101,13 @@ base::DictValue BuildTargetDescriptor(
int process_id,
int routing_id,
ui::AXMode accessibility_mode,
base::ProcessId pid = base::kNullProcessId) {
base::ProcessHandle handle = base::kNullProcessHandle) {
base::DictValue target_data;
target_data.Set(kProcessIdField, process_id);
target_data.Set(kRoutingIdField, routing_id);
target_data.Set(kUrlField, url.spec());
target_data.Set(kNameField, base::EscapeForHTML(name));
target_data.Set(kPidField, static_cast<int>(pid));
target_data.Set(kPidField, static_cast<int>(base::GetProcId(handle)));
target_data.Set(kFaviconUrlField, favicon_url.spec());
target_data.Set(kAccessibilityModeField,
static_cast<int>(accessibility_mode.flags()));
@@ -138,12 +138,9 @@ base::DictValue BuildTargetDescriptor(content::RenderViewHost* rvh) {
accessibility_mode = web_contents->GetAccessibilityMode();
}
const auto& process = rvh->GetProcess()->GetProcess();
const auto pid = process.IsValid() ? process.Pid() : base::kNullProcessId;
return BuildTargetDescriptor(url, title, favicon_url,
rvh->GetProcess()->GetDeprecatedID(),
rvh->GetRoutingID(), accessibility_mode, pid);
rvh->GetRoutingID(), accessibility_mode);
}
base::DictValue BuildTargetDescriptor(electron::NativeWindow* window) {

View File

@@ -228,14 +228,14 @@ void WebContentsPermissionHelper::RequestPermission(
}
bool WebContentsPermissionHelper::CheckPermission(
content::RenderFrameHost* requesting_frame,
blink::PermissionType permission,
base::DictValue details) const {
auto* rfh = web_contents_->GetPrimaryMainFrame();
auto* permission_manager = static_cast<ElectronPermissionManager*>(
web_contents_->GetBrowserContext()->GetPermissionControllerDelegate());
auto origin = requesting_frame->GetLastCommittedOrigin().GetURL();
return permission_manager->CheckPermissionWithDetails(
permission, requesting_frame, origin, std::move(details));
auto origin = web_contents_->GetLastCommittedURL();
return permission_manager->CheckPermissionWithDetails(permission, rfh, origin,
std::move(details));
}
void WebContentsPermissionHelper::RequestFullscreenPermission(
@@ -313,7 +313,6 @@ void WebContentsPermissionHelper::RequestOpenExternalPermission(
}
bool WebContentsPermissionHelper::CheckMediaAccessPermission(
content::RenderFrameHost* requesting_frame,
const url::Origin& security_origin,
blink::mojom::MediaStreamType type) const {
base::DictValue details;
@@ -322,16 +321,14 @@ bool WebContentsPermissionHelper::CheckMediaAccessPermission(
auto blink_type = type == blink::mojom::MediaStreamType::DEVICE_AUDIO_CAPTURE
? blink::PermissionType::AUDIO_CAPTURE
: blink::PermissionType::VIDEO_CAPTURE;
return CheckPermission(requesting_frame, blink_type, std::move(details));
return CheckPermission(blink_type, std::move(details));
}
bool WebContentsPermissionHelper::CheckSerialAccessPermission(
content::RenderFrameHost* requesting_frame) const {
const url::Origin& embedding_origin) const {
base::DictValue details;
details.Set("securityOrigin",
requesting_frame->GetLastCommittedOrigin().GetURL().spec());
return CheckPermission(requesting_frame, blink::PermissionType::SERIAL,
std::move(details));
details.Set("securityOrigin", embedding_origin.GetURL().spec());
return CheckPermission(blink::PermissionType::SERIAL, std::move(details));
}
WEB_CONTENTS_USER_DATA_KEY_IMPL(WebContentsPermissionHelper);

View File

@@ -47,11 +47,9 @@ class WebContentsPermissionHelper
const GURL& url);
// Synchronous Checks
bool CheckMediaAccessPermission(content::RenderFrameHost* requesting_frame,
const url::Origin& security_origin,
bool CheckMediaAccessPermission(const url::Origin& security_origin,
blink::mojom::MediaStreamType type) const;
bool CheckSerialAccessPermission(
content::RenderFrameHost* requesting_frame) const;
bool CheckSerialAccessPermission(const url::Origin& embedding_origin) const;
private:
explicit WebContentsPermissionHelper(content::WebContents* web_contents);
@@ -63,8 +61,7 @@ class WebContentsPermissionHelper
bool user_gesture = false,
base::DictValue details = {});
bool CheckPermission(content::RenderFrameHost* requesting_frame,
blink::PermissionType permission,
bool CheckPermission(blink::PermissionType permission,
base::DictValue details) const;
// TODO(clavin): refactor to use the WebContents provided by the

View File

@@ -66,6 +66,10 @@ void SetHiddenValue(v8::Isolate* isolate,
object->SetPrivate(context, privateKey, value);
}
int32_t GetObjectHash(v8::Local<v8::Object> object) {
return object->GetIdentityHash();
}
void TakeHeapSnapshot(v8::Isolate* isolate) {
isolate->GetHeapProfiler()->TakeHeapSnapshot();
}
@@ -99,6 +103,7 @@ void Initialize(v8::Local<v8::Object> exports,
gin_helper::Dictionary dict{isolate, exports};
dict.SetMethod("getHiddenValue", &GetHiddenValue);
dict.SetMethod("setHiddenValue", &SetHiddenValue);
dict.SetMethod("getObjectHash", &GetObjectHash);
dict.SetMethod("takeHeapSnapshot", &TakeHeapSnapshot);
dict.SetMethod("requestGarbageCollectionForTesting",
&RequestGarbageCollectionForTesting);

View File

@@ -25,6 +25,10 @@ bool IsPrintingEnabled() {
return BUILDFLAG(ENABLE_PRINTING);
}
bool IsPromptAPIEnabled() {
return BUILDFLAG(ENABLE_PROMPT_API);
}
bool IsExtensionsEnabled() {
return BUILDFLAG(ENABLE_ELECTRON_EXTENSIONS);
}
@@ -48,6 +52,7 @@ void Initialize(v8::Local<v8::Object> exports,
dict.SetMethod("isFakeLocationProviderEnabled",
&IsFakeLocationProviderEnabled);
dict.SetMethod("isPrintingEnabled", &IsPrintingEnabled);
dict.SetMethod("isPromptAPIEnabled", &IsPromptAPIEnabled);
dict.SetMethod("isComponentBuild", &IsComponentBuild);
dict.SetMethod("isExtensionsEnabled", &IsExtensionsEnabled);
}

View File

@@ -56,6 +56,16 @@ v8::Local<v8::Value> CustomEmit(v8::Isolate* isolate,
converted_args));
}
template <typename... Args>
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
v8::Local<v8::Object> object,
const char* method_name,
Args&&... args) {
v8::EscapableHandleScope scope(isolate);
return scope.Escape(
CustomEmit(isolate, object, method_name, std::forward<Args>(args)...));
}
template <typename T, typename... Args>
v8::Local<v8::Value> CallMethod(v8::Isolate* isolate,
gin_helper::DeprecatedWrappable<T>* object,

View File

@@ -4,7 +4,6 @@
#include "shell/common/gin_helper/wrappable.h"
#include "base/task/sequenced_task_runner.h"
#include "gin/object_template_builder.h"
#include "gin/public/isolate_holder.h"
#include "shell/common/gin_helper/dictionary.h"
@@ -91,22 +90,7 @@ void WrappableBase::SecondWeakCallback(
if (gin::IsolateHolder::DestroyedMicrotasksRunner()) {
return;
}
// Defer destruction to a posted task. V8's second-pass weak callbacks run
// inside a DisallowJavascriptExecutionScope (they may touch the V8 API but
// must not invoke JS). Several Electron Wrappables (e.g. WebContents) emit
// JS events from their destructors, so deleting synchronously here can
// crash with "Invoke in DisallowJavascriptExecutionScope" — see
// https://github.com/electron/electron/issues/47420. Posting via the
// current sequence's task runner ensures the destructor runs once V8 has
// left the GC scope. If no task runner is available (e.g. early/late in
// process lifetime), fall back to synchronous deletion.
auto* wrappable = static_cast<WrappableBase*>(data.GetInternalField(0));
if (base::SequencedTaskRunner::HasCurrentDefault()) {
base::SequencedTaskRunner::GetCurrentDefault()->DeleteSoon(FROM_HERE,
wrappable);
} else {
delete wrappable;
}
delete static_cast<WrappableBase*>(data.GetInternalField(0));
}
DeprecatedWrappableBase::DeprecatedWrappableBase() = default;
@@ -142,19 +126,9 @@ void DeprecatedWrappableBase::SecondWeakCallback(
const v8::WeakCallbackInfo<DeprecatedWrappableBase>& data) {
if (gin::IsolateHolder::DestroyedMicrotasksRunner())
return;
// See WrappableBase::SecondWeakCallback for why deletion is posted: V8's
// second-pass weak callbacks run inside a DisallowJavascriptExecutionScope,
// and several Wrappables emit JS events from their destructors.
// https://github.com/electron/electron/issues/47420
DeprecatedWrappableBase* wrappable = data.GetParameter();
if (!wrappable)
return;
if (base::SequencedTaskRunner::HasCurrentDefault()) {
base::SequencedTaskRunner::GetCurrentDefault()->DeleteSoon(FROM_HERE,
wrappable);
} else {
if (wrappable)
delete wrappable;
}
}
v8::MaybeLocal<v8::Object> DeprecatedWrappableBase::GetWrapperImpl(

View File

@@ -6,7 +6,6 @@
#define ELECTRON_SHELL_COMMON_GIN_HELPER_WRAPPABLE_BASE_H_
#include "base/memory/raw_ptr.h"
#include "base/task/sequenced_task_runner_helpers.h"
#include "v8/include/v8-forward.h"
namespace gin {
@@ -76,11 +75,6 @@ class DeprecatedWrappableBase {
DeprecatedWrappableBase();
virtual ~DeprecatedWrappableBase();
// SecondWeakCallback posts destruction via DeleteSoon so that destructors
// (which may emit JS events) run outside V8's GC scope. DeleteSoon needs
// access to the protected destructor.
friend class base::DeleteHelper<DeprecatedWrappableBase>;
// Overrides of this method should be declared final and not overridden again.
virtual gin::ObjectTemplateBuilder GetObjectTemplateBuilder(
v8::Isolate* isolate);

View File

@@ -116,6 +116,7 @@
V(electron_browser_event_emitter) \
V(electron_browser_system_preferences) \
V(electron_common_net) \
V(electron_utility_local_ai_handler) \
V(electron_utility_parent_port)
#define ELECTRON_TESTING_BINDINGS(V) V(electron_common_testing)

View File

@@ -151,6 +151,22 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
return env;
}
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate) {
auto context = isolate->GetCurrentContext();
auto global_object = context->Global();
auto value =
global_object->Get(context, gin::StringToV8(isolate, "AbortController"))
.ToLocalChecked();
DCHECK(!value.IsEmpty() && value->IsObject());
DCHECK(value->IsFunction());
auto constructor = value.As<v8::Function>();
auto instance =
constructor->NewInstance(context, 0, nullptr).ToLocalChecked();
return instance;
}
ExplicitMicrotasksScope::ExplicitMicrotasksScope(v8::MicrotaskQueue* queue)
: microtask_queue_(queue), original_policy_(queue->microtasks_policy()) {
// In browser-like processes, some nested run loops (macOS usually) may

View File

@@ -66,6 +66,8 @@ node::Environment* CreateEnvironment(v8::Isolate* isolate,
node::EnvironmentFlags::Flags env_flags,
std::string_view process_type = "");
v8::Local<v8::Object> CreateAbortController(v8::Isolate* isolate);
// A scope that temporarily changes the microtask policy to explicit. Use this
// anywhere that can trigger Node.js or uv_run().
//

View File

@@ -11,8 +11,10 @@
#include "base/no_destructor.h"
#include "base/process/process.h"
#include "base/strings/utf_string_conversions.h"
#include "electron/buildflags/buildflags.h"
#include "electron/fuses.h"
#include "electron/mas.h"
#include "mojo/public/cpp/bindings/self_owned_receiver.h"
#include "net/base/network_change_notifier.h"
#include "services/network/public/cpp/wrapper_shared_url_loader_factory.h"
#include "services/network/public/mojom/host_resolver.mojom.h"
@@ -30,6 +32,12 @@
#include "shell/common/crash_keys.h"
#endif
#if BUILDFLAG(ENABLE_PROMPT_API)
#include "shell/utility/ai/utility_ai_manager.h"
#include "url/gurl.h"
#include "url/origin.h"
#endif // BUILDFLAG(ENABLE_PROMPT_API)
namespace electron {
mojo::Remote<node::mojom::NodeServiceClient>& GetRemote() {
@@ -215,4 +223,15 @@ void NodeService::UpdateURLLoaderFactory(
params->use_network_observer_from_url_loader_factory);
}
#if BUILDFLAG(ENABLE_PROMPT_API)
void NodeService::BindAIManager(
node::mojom::BindAIManagerParamsPtr params,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) {
mojo::MakeSelfOwnedReceiver(
std::make_unique<UtilityAIManager>(params->web_contents_id,
params->security_origin),
std::move(ai_manager));
}
#endif // BUILDFLAG(ENABLE_PROMPT_API)
} // namespace electron

View File

@@ -7,6 +7,7 @@
#include <memory>
#include "electron/buildflags/buildflags.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
@@ -68,6 +69,12 @@ class NodeService : public node::mojom::NodeService {
void UpdateURLLoaderFactory(
node::mojom::URLLoaderFactoryParamsPtr params) override;
#if BUILDFLAG(ENABLE_PROMPT_API)
void BindAIManager(
node::mojom::BindAIManagerParamsPtr params,
mojo::PendingReceiver<blink::mojom::AIManager> ai_manager) override;
#endif // BUILDFLAG(ENABLE_PROMPT_API)
private:
// This needs to be initialized first so that it can be destroyed last
// after the node::Environment is destroyed. This ensures that if

View File

@@ -2,6 +2,7 @@
# Use of this source code is governed by the MIT license that can be
# found in the LICENSE file.
import("//electron/buildflags/buildflags.gni")
import("//mojo/public/tools/bindings/mojom.gni")
mojom("mojom") {
@@ -11,4 +12,8 @@ mojom("mojom") {
"//sandbox/policy/mojom",
"//third_party/blink/public/mojom:mojom_core",
]
if (enable_prompt_api) {
enabled_features = [ "enable_prompt_api" ]
}
}

View File

@@ -8,7 +8,9 @@ import "mojo/public/mojom/base/file_path.mojom";
import "sandbox/policy/mojom/sandbox.mojom";
import "services/network/public/mojom/host_resolver.mojom";
import "services/network/public/mojom/url_loader_factory.mojom";
import "third_party/blink/public/mojom/ai/ai_manager.mojom";
import "third_party/blink/public/mojom/messaging/message_port_descriptor.mojom";
import "url/mojom/origin.mojom";
struct URLLoaderFactoryParams {
pending_remote<network.mojom.URLLoaderFactory> url_loader_factory;
@@ -24,6 +26,11 @@ struct NodeServiceParams {
URLLoaderFactoryParams url_loader_factory_params;
};
struct BindAIManagerParams {
int32? web_contents_id;
url.mojom.Origin security_origin;
};
interface NodeServiceClient {
OnV8FatalError(string location, string report);
};
@@ -34,4 +41,8 @@ interface NodeService {
pending_remote<NodeServiceClient> client_remote);
UpdateURLLoaderFactory(URLLoaderFactoryParams params);
[EnableIf=enable_prompt_api]
BindAIManager(BindAIManagerParams params,
pending_receiver<blink.mojom.AIManager> ai_manager);
};

View File

@@ -0,0 +1,726 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/ai/utility_ai_language_model.h"
#include <string_view>
#include "base/no_destructor.h"
#include "base/notimplemented.h"
#include "shell/browser/javascript_environment.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/std_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/gin_helper/event_emitter_caller.h"
#include "shell/common/node_includes.h"
#include "shell/common/node_util.h"
#include "shell/common/v8_util.h"
#include "shell/utility/ai/utility_ai_manager.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
namespace gin {
template <>
struct Converter<on_device_model::mojom::ResponseConstraintPtr> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
const on_device_model::mojom::ResponseConstraintPtr& val) {
if (val.is_null())
return v8::Undefined(isolate);
if (val->is_json_schema()) {
return v8::JSON::Parse(isolate->GetCurrentContext(),
StringToV8(isolate, val->get_json_schema()))
.ToLocalChecked();
} else if (val->is_regex()) {
return v8::RegExp::New(isolate->GetCurrentContext(),
StringToV8(isolate, val->get_regex()),
v8::RegExp::kNone)
.ToLocalChecked();
}
return v8::Undefined(isolate);
}
};
template <>
struct Converter<blink::mojom::AILanguageModelPromptRole> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
blink::mojom::AILanguageModelPromptRole value) {
switch (value) {
case blink::mojom::AILanguageModelPromptRole::kSystem:
return StringToV8(isolate, "system");
case blink::mojom::AILanguageModelPromptRole::kUser:
return StringToV8(isolate, "user");
case blink::mojom::AILanguageModelPromptRole::kAssistant:
return StringToV8(isolate, "assistant");
default:
return StringToV8(isolate, "unknown");
}
}
};
template <>
struct Converter<blink::mojom::AILanguageModelPromptContentPtr> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
const blink::mojom::AILanguageModelPromptContentPtr& val) {
if (val.is_null())
return v8::Undefined(isolate);
auto dict = gin::Dictionary::CreateEmpty(isolate);
if (val->is_text()) {
dict.Set("type", "text");
dict.Set("value", val->get_text());
} else if (val->is_bitmap()) {
// Convert the bitmap to an ArrayBuffer
// TODO - Are we going to make any guarantees about the shape of the image
// data?
SkBitmap& bitmap = val->get_bitmap();
const auto dst_info = SkImageInfo::MakeN32Premul(bitmap.dimensions());
const size_t dst_n_bytes = dst_info.computeMinByteSize();
auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
if (!bitmap.readPixels(dst_info, dst_buf->Data(), dst_info.minRowBytes(),
0, 0)) {
auto err = v8::Exception::TypeError(
gin::StringToV8(isolate, "Invalid bitmap content in prompt"));
node::errors::TriggerUncaughtException(isolate, err, {});
}
dict.Set("type", "image");
dict.Set("value", dst_buf);
} else if (val->is_audio()) {
// Convert the audio data to an ArrayBuffer
// TODO - Are we going to make any guarantees about the shape of the audio
// data?
on_device_model::mojom::AudioDataPtr& audio_data = val->get_audio();
std::vector<float>& raw_data = audio_data->data;
const size_t dst_n_bytes =
sizeof(std::remove_reference_t<decltype(raw_data)>::value_type) *
raw_data.size();
auto dst_buf = v8::ArrayBuffer::New(isolate, dst_n_bytes);
UNSAFE_BUFFERS(
std::ranges::copy(raw_data, static_cast<char*>(dst_buf->Data())));
dict.Set("type", "audio");
dict.Set("value", dst_buf);
}
return ConvertToV8(isolate, dict);
}
};
v8::Local<v8::Value> Converter<blink::mojom::AILanguageModelPromptPtr>::ToV8(
v8::Isolate* isolate,
const blink::mojom::AILanguageModelPromptPtr& val) {
if (val.is_null())
return v8::Undefined(isolate);
auto dict = gin::Dictionary::CreateEmpty(isolate);
dict.Set("role", val->role);
dict.Set("content", val->content);
dict.Set("prefix", val->is_prefix);
return ConvertToV8(isolate, dict);
}
} // namespace gin
namespace electron {
namespace {
constexpr std::string_view kIsReadableStreamKey = "isReadableStream";
constexpr std::string_view kIsLanguageModelKey = "isLanguageModel";
constexpr std::string_view kIsLanguageModelClassKey = "isLanguageModelClass";
v8::Local<v8::Function> GetPrivateBoolean(v8::Isolate* const isolate,
const v8::Local<v8::Context>& context,
std::string_view key) {
auto binding_key = gin::StringToV8(isolate, key);
auto private_binding_key = v8::Private::ForApi(isolate, binding_key);
auto global_object = context->Global();
auto value =
global_object->GetPrivate(context, private_binding_key).ToLocalChecked();
if (value.IsEmpty() || !value->IsFunction()) {
LOG(FATAL) << "Attempted to get the '" << key
<< "' value but it was missing";
}
return value.As<v8::Function>();
}
bool IsReadableStream(v8::Isolate* isolate, v8::Local<v8::Value> val) {
static base::NoDestructor<v8::Global<v8::Function>> is_readable_stream;
auto context = isolate->GetCurrentContext();
if (is_readable_stream.get()->IsEmpty()) {
is_readable_stream->Reset(
isolate, GetPrivateBoolean(isolate, context, kIsReadableStreamKey));
}
v8::Local<v8::Value> args[] = {val};
v8::Local<v8::Value> result =
is_readable_stream->Get(isolate)
->Call(context, v8::Null(isolate), std::size(args), args)
.ToLocalChecked();
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
uint64_t GetContextUsage(v8::Isolate* isolate,
v8::Local<v8::Object> language_model) {
auto context = isolate->GetCurrentContext();
v8::Local<v8::Value> val =
language_model->Get(context, gin::StringToV8(isolate, "contextUsage"))
.ToLocalChecked();
uint64_t token_count = 0;
if (val->IsNumber()) {
gin::ConvertFromV8(isolate, val, &token_count);
}
return token_count;
}
// Owns itself. Will live as long as there's more data to process
// and the Mojo remote is still connected.
class PromptResponder {
public:
PromptResponder(v8::Isolate* isolate,
v8::Local<v8::Value> value,
v8::Local<v8::Object> abort_controller,
v8::Local<v8::Object> language_model,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder,
UtilityAILanguageModel* model) {
abort_controller_.Reset(isolate, abort_controller);
language_model_.Reset(isolate, language_model);
responder_.Bind(std::move(pending_responder));
responder_.set_disconnect_handler(
base::BindOnce(&PromptResponder::DeleteThis, base::Unretained(this)));
destroy_subscription_ = model->AddDestroyObserver(base::BindRepeating(
&PromptResponder::OnModelDestroyed, base::Unretained(this)));
Respond(isolate, value);
}
// disable copy
PromptResponder(const PromptResponder&) = delete;
PromptResponder& operator=(const PromptResponder&) = delete;
private:
void OnModelDestroyed() {
// Drop the subscription since the model is already being destroyed.
destroy_subscription_ = {};
responder_->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
DeleteThis();
}
void Respond(v8::Isolate* isolate, v8::Local<v8::Value> value) {
if (value->IsPromise()) {
auto promise = value.As<v8::Promise>();
auto then_cb = base::BindOnce(
[](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->RespondImplementation(isolate, result);
}
},
weak_ptr_factory_.GetWeakPtr(), isolate);
auto catch_cb = base::BindOnce(
[](base::WeakPtr<PromptResponder> weak_ptr,
v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->SendError();
weak_ptr->DeleteThis();
}
},
weak_ptr_factory_.GetWeakPtr());
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
} else {
RespondImplementation(isolate, value);
}
}
void RespondImplementation(v8::Isolate* isolate, v8::Local<v8::Value> val) {
std::string response;
if (val->IsString() && gin::ConvertFromV8(isolate, val, &response)) {
responder_->OnStreaming(response);
uint64_t token_count =
GetContextUsage(isolate, language_model_.Get(isolate));
responder_->OnCompletion(
blink::mojom::ModelExecutionContextInfo::New(token_count));
completed_ = true;
DeleteThis();
} else if (IsReadableStream(isolate, val)) {
v8::Local<v8::Value> reader =
gin_helper::CallMethod(isolate, val.As<v8::Object>(), "getReader");
DCHECK(reader->IsObject());
readable_stream_reader_.Reset(isolate, reader.As<v8::Object>());
Read(isolate);
} else {
SendError();
DeleteThis();
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate, "Invalid return value from LanguageModel.prompt()"));
node::errors::TriggerUncaughtException(isolate, err, {});
}
}
void Read(v8::Isolate* isolate) {
v8::Local<v8::Value> val = gin_helper::CallMethod(
isolate, readable_stream_reader_.Get(isolate), "read");
DCHECK(val->IsPromise());
auto promise = val.As<v8::Promise>();
auto then_cb = base::BindOnce(
[](base::WeakPtr<PromptResponder> weak_ptr, v8::Isolate* isolate,
v8::Local<v8::Value> result) {
if (weak_ptr) {
CHECK(result->IsObject());
v8::Local<v8::Value> done =
result.As<v8::Object>()
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "done"))
.ToLocalChecked();
CHECK(done->IsBoolean());
if (done.As<v8::Boolean>()->Value()) {
uint64_t token_count = GetContextUsage(
isolate, weak_ptr->language_model_.Get(isolate));
weak_ptr->responder_->OnCompletion(
blink::mojom::ModelExecutionContextInfo::New(token_count));
weak_ptr->completed_ = true;
weak_ptr->DeleteThis();
} else {
v8::Local<v8::Value> val =
result.As<v8::Object>()
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "value"))
.ToLocalChecked();
DCHECK(val->IsString());
std::string value;
if (gin::ConvertFromV8(isolate, val, &value)) {
weak_ptr->responder_->OnStreaming(value);
weak_ptr->Read(isolate);
} else {
weak_ptr->SendError();
weak_ptr->DeleteThis();
}
}
}
},
weak_ptr_factory_.GetWeakPtr(), isolate);
auto catch_cb = base::BindOnce(
[](base::WeakPtr<PromptResponder> weak_ptr,
v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->SendError();
weak_ptr->DeleteThis();
}
},
weak_ptr_factory_.GetWeakPtr());
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
}
void SendError() {
responder_->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
/*quota_error_info=*/nullptr);
}
void DeleteThis() {
destroy_subscription_ = {};
weak_ptr_factory_.InvalidateWeakPtrs();
if (!completed_) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
if (!readable_stream_reader_.IsEmpty()) {
gin_helper::CallMethod(isolate, readable_stream_reader_.Get(isolate),
"cancel");
}
gin_helper::CallMethod(isolate, abort_controller_.Get(isolate), "abort");
}
delete this;
}
bool completed_ = false;
v8::Global<v8::Object> readable_stream_reader_;
v8::Global<v8::Object> abort_controller_;
v8::Global<v8::Object> language_model_;
mojo::Remote<blink::mojom::ModelStreamingResponder> responder_;
base::CallbackListSubscription destroy_subscription_;
base::WeakPtrFactory<PromptResponder> weak_ptr_factory_{this};
};
} // namespace
UtilityAILanguageModel::UtilityAILanguageModel(
v8::Local<v8::Object> language_model,
base::WeakPtr<UtilityAIManager> manager)
: manager_(std::move(manager)) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
language_model_.Reset(isolate, language_model);
responder_set_.set_disconnect_handler(
base::BindRepeating(&UtilityAILanguageModel::OnResponderDisconnect,
weak_ptr_factory_.GetWeakPtr()));
}
UtilityAILanguageModel::~UtilityAILanguageModel() {
if (!is_destroyed_) {
Destroy();
}
}
base::CallbackListSubscription UtilityAILanguageModel::AddDestroyObserver(
base::RepeatingClosure callback) {
return on_destroy_.Add(std::move(callback));
}
void UtilityAILanguageModel::OnResponderDisconnect(
mojo::RemoteSetElementId id) {
auto it = abort_controllers_.find(id);
if (it != abort_controllers_.end()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort");
abort_controllers_.erase(it);
}
}
blink::mojom::ModelStreamingResponder* UtilityAILanguageModel::GetResponder(
mojo::RemoteSetElementId responder_id) {
return responder_set_.Get(responder_id);
}
// static
bool UtilityAILanguageModel::IsLanguageModel(v8::Isolate* isolate,
v8::Local<v8::Value> val) {
static base::NoDestructor<v8::Global<v8::Function>> is_language_model;
auto context = isolate->GetCurrentContext();
if (is_language_model.get()->IsEmpty()) {
is_language_model->Reset(
isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelKey));
}
v8::Local<v8::Value> args[] = {val};
v8::Local<v8::Value> result =
is_language_model->Get(isolate)
->Call(context, v8::Null(isolate), std::size(args), args)
.ToLocalChecked();
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
// static
bool UtilityAILanguageModel::IsLanguageModelClass(v8::Isolate* isolate,
v8::Local<v8::Value> val) {
static base::NoDestructor<v8::Global<v8::Function>> is_language_model_class;
auto context = isolate->GetCurrentContext();
if (is_language_model_class.get()->IsEmpty()) {
is_language_model_class->Reset(
isolate, GetPrivateBoolean(isolate, context, kIsLanguageModelClassKey));
}
v8::Local<v8::Value> args[] = {val};
v8::Local<v8::Value> result =
is_language_model_class->Get(isolate)
->Call(context, v8::Null(isolate), std::size(args), args)
.ToLocalChecked();
return result->IsBoolean() && result.As<v8::Boolean>()->Value();
}
void UtilityAILanguageModel::Prompt(
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
on_device_model::mojom::ResponseConstraintPtr constraint,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
if (is_destroyed_) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
return;
}
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
if (!constraint.is_null()) {
options.Set("responseConstraint", gin::ConvertToV8(isolate, constraint));
}
options.Set("signal", abort_controller
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "signal"))
.ToLocalChecked());
v8::Local<v8::Value> val = gin_helper::CallMethod(
isolate, language_model_.Get(isolate), "prompt", prompts, options);
new PromptResponder(isolate, val, abort_controller,
language_model_.Get(isolate),
std::move(pending_responder), this);
}
void UtilityAILanguageModel::Append(
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
if (is_destroyed_) {
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
return;
}
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(pending_responder));
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
abort_controllers_.emplace(responder_id,
v8::Global<v8::Object>(isolate, abort_controller));
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
options.Set("signal", abort_controller
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "signal"))
.ToLocalChecked());
v8::Local<v8::Value> val = gin_helper::CallMethod(
isolate, language_model_.Get(isolate), "append", prompts, options);
auto SendResponse =
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr, v8::Isolate* isolate,
mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
if (!weak_ptr)
return;
weak_ptr->abort_controllers_.erase(responder_id);
blink::mojom::ModelStreamingResponder* responder =
weak_ptr->GetResponder(responder_id);
if (!responder) {
return;
}
uint64_t token_count =
GetContextUsage(isolate, weak_ptr->language_model_.Get(isolate));
responder->OnCompletion(
blink::mojom::ModelExecutionContextInfo::New(token_count));
};
if (val->IsPromise()) {
auto promise = val.As<v8::Promise>();
auto then_cb = base::BindOnce(SendResponse, weak_ptr_factory_.GetWeakPtr(),
isolate, responder_id);
auto catch_cb = base::BindOnce(
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
mojo::RemoteSetElementId responder_id, v8::Local<v8::Value> result) {
if (!weak_ptr)
return;
weak_ptr->abort_controllers_.erase(responder_id);
blink::mojom::ModelStreamingResponder* responder =
weak_ptr->GetResponder(responder_id);
if (!responder) {
return;
}
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorUnknown,
/*quota_error_info=*/nullptr);
},
weak_ptr_factory_.GetWeakPtr(), responder_id);
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
} else {
// The method is supposed to return a promise, but for
// convenience allow developers to return a value directly
SendResponse(weak_ptr_factory_.GetWeakPtr(), isolate, responder_id, val);
}
}
void UtilityAILanguageModel::Fork(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client) {
if (is_destroyed_ || !manager_) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
client_remote->OnError(
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
/*quota_error_info=*/nullptr);
return;
}
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
manager_->CreateLanguageModelInternal(
isolate, std::move(client), language_model_.Get(isolate), "clone",
gin_helper::Dictionary::CreateEmpty(isolate),
blink::mojom::AILanguageModelCreateOptions::New());
}
void UtilityAILanguageModel::Destroy() {
if (is_destroyed_) {
return;
}
is_destroyed_ = true;
// Notify observers (e.g. in-progress PromptResponders) before
// tearing down the responder set and abort controllers.
on_destroy_.Notify();
for (auto& responder : responder_set_) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorSessionDestroyed,
/*quota_error_info=*/nullptr);
}
responder_set_.Clear();
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
for (auto& [id, controller] : abort_controllers_) {
gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
}
abort_controllers_.clear();
for (auto& controller : measure_abort_controllers_) {
gin_helper::CallMethod(isolate, controller.Get(isolate), "abort");
}
measure_abort_controllers_.clear();
gin_helper::CallMethod(isolate, language_model_.Get(isolate), "destroy");
}
void UtilityAILanguageModel::MeasureInputUsage(
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
MeasureInputUsageCallback callback) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
measure_abort_controllers_.emplace_back(isolate, abort_controller);
auto abort_it = std::prev(measure_abort_controllers_.end());
auto options = gin_helper::Dictionary::CreateEmpty(isolate);
options.Set("signal", abort_controller
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "signal"))
.ToLocalChecked());
v8::Local<v8::Value> val =
gin_helper::CallMethod(isolate, language_model_.Get(isolate),
"measureContextUsage", input, options);
auto RunCallback = [](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
std::list<v8::Global<v8::Object>>::iterator abort_it,
v8::Isolate* isolate,
MeasureInputUsageCallback callback,
v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->measure_abort_controllers_.erase(abort_it);
}
uint32_t input_tokens = 0;
if (result->IsNumber() &&
gin::ConvertFromV8(isolate, result, &input_tokens)) {
std::move(callback).Run(std::move(input_tokens));
} else if (result->IsNull()) {
std::move(callback).Run(std::nullopt);
} else {
std::move(callback).Run(std::nullopt);
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate,
"Invalid return value from LanguageModel.measureContextUsage()"));
node::errors::TriggerUncaughtException(isolate, err, {});
}
};
if (val->IsPromise()) {
auto promise = val.As<v8::Promise>();
auto split_callback = base::SplitOnceCallback(std::move(callback));
auto then_cb =
base::BindOnce(RunCallback, weak_ptr_factory_.GetWeakPtr(), abort_it,
isolate, std::move(split_callback.first));
auto catch_cb = base::BindOnce(
[](base::WeakPtr<UtilityAILanguageModel> weak_ptr,
std::list<v8::Global<v8::Object>>::iterator abort_it,
MeasureInputUsageCallback callback, v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->measure_abort_controllers_.erase(abort_it);
}
std::move(callback).Run(std::nullopt);
},
weak_ptr_factory_.GetWeakPtr(), abort_it,
std::move(split_callback.second));
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
} else {
// The method is supposed to return a promise, but for
// convenience allow developers to return a value directly
RunCallback(weak_ptr_factory_.GetWeakPtr(), abort_it, isolate,
std::move(callback), val);
}
}
} // namespace electron

View File

@@ -0,0 +1,99 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_
#include <list>
#include <vector>
#include "base/callback_list.h"
#include "base/memory/weak_ptr.h"
#include "gin/converter.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "v8/include/v8.h"
namespace electron {
class UtilityAIManager;
class UtilityAILanguageModel : public blink::mojom::AILanguageModel {
public:
UtilityAILanguageModel(v8::Local<v8::Object> language_model,
base::WeakPtr<UtilityAIManager> manager);
UtilityAILanguageModel(const UtilityAILanguageModel&) = delete;
UtilityAILanguageModel& operator=(const UtilityAILanguageModel&) = delete;
~UtilityAILanguageModel() override;
// Subscribe to be notified when this model is destroyed. The returned
// subscription auto-unregisters when destroyed.
[[nodiscard]] base::CallbackListSubscription AddDestroyObserver(
base::RepeatingClosure callback);
static bool IsLanguageModel(v8::Isolate* isolate, v8::Local<v8::Value> val);
static bool IsLanguageModelClass(v8::Isolate* isolate,
v8::Local<v8::Value> val);
// `blink::mojom::AILanguageModel` implementation.
void Prompt(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
on_device_model::mojom::ResponseConstraintPtr constraint,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) override;
void Append(std::vector<blink::mojom::AILanguageModelPromptPtr> prompts,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) override;
void Fork(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client) override;
void Destroy() override;
void MeasureInputUsage(
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
MeasureInputUsageCallback callback) override;
private:
void OnResponderDisconnect(mojo::RemoteSetElementId id);
blink::mojom::ModelStreamingResponder* GetResponder(
mojo::RemoteSetElementId responder_id);
base::WeakPtr<UtilityAIManager> manager_;
v8::Global<v8::Object> language_model_;
bool is_destroyed_ = false;
mojo::RemoteSet<blink::mojom::ModelStreamingResponder> responder_set_;
// Maps each in-progress Prompt/Append responder to its AbortController
// so we can abort the JS-side operation if the responder disconnects.
absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
abort_controllers_;
// Tracks abort controllers for in-progress MeasureInputUsage calls.
std::list<v8::Global<v8::Object>> measure_abort_controllers_;
// Notified when this model is destroyed, allowing in-progress
// PromptResponder instances to clean up.
base::RepeatingClosureList on_destroy_;
base::WeakPtrFactory<UtilityAILanguageModel> weak_ptr_factory_{this};
};
} // namespace electron
namespace gin {
template <>
struct Converter<blink::mojom::AILanguageModelPromptPtr> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
const blink::mojom::AILanguageModelPromptPtr& val);
};
} // namespace gin
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_LANGUAGE_MODEL_H_

View File

@@ -0,0 +1,523 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/ai/utility_ai_manager.h"
#include <optional>
#include <utility>
#include "base/containers/fixed_flat_map.h"
#include "base/notimplemented.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "shell/browser/javascript_environment.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_converters/std_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/gin_helper/event_emitter_caller.h"
#include "shell/common/node_includes.h"
#include "shell/common/node_util.h"
#include "shell/utility/ai/utility_ai_language_model.h"
#include "shell/utility/api/electron_api_local_ai_handler.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
#include "url/gurl.h"
#include "url/origin.h"
#include "v8/include/v8.h"
namespace gin {
template <>
struct Converter<blink::mojom::ModelAvailabilityCheckResult> {
static bool FromV8(v8::Isolate* isolate,
v8::Local<v8::Value> val,
blink::mojom::ModelAvailabilityCheckResult* out) {
using Result = blink::mojom::ModelAvailabilityCheckResult;
static constexpr auto Lookup =
base::MakeFixedFlatMap<std::string_view, Result>({
{"available", Result::kAvailable},
{"unavailable", Result::kUnavailableUnknown},
{"downloading", Result::kDownloading},
{"downloadable", Result::kDownloadable},
});
return FromV8WithLookup(isolate, val, Lookup, out);
}
};
template <>
struct Converter<blink::mojom::AILanguageModelPromptType> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
blink::mojom::AILanguageModelPromptType value) {
switch (value) {
case blink::mojom::AILanguageModelPromptType::kText:
return StringToV8(isolate, "text");
case blink::mojom::AILanguageModelPromptType::kImage:
return StringToV8(isolate, "image");
case blink::mojom::AILanguageModelPromptType::kAudio:
return StringToV8(isolate, "audio");
default:
return StringToV8(isolate, "unknown");
}
}
};
template <>
struct Converter<blink::mojom::AILanguageCodePtr> {
static v8::Local<v8::Value> ToV8(v8::Isolate* isolate,
const blink::mojom::AILanguageCodePtr& val) {
if (val.is_null()) {
return v8::Undefined(isolate);
}
return StringToV8(isolate, val->code);
}
};
template <>
struct Converter<blink::mojom::AILanguageModelExpectedPtr> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
const blink::mojom::AILanguageModelExpectedPtr& val) {
if (val.is_null()) {
return v8::Undefined(isolate);
}
auto dict = gin::Dictionary::CreateEmpty(isolate);
dict.Set("type", val->type);
if (val->languages.has_value() && !val->languages->empty()) {
dict.Set("languages", val->languages.value());
}
return ConvertToV8(isolate, dict);
}
};
template <>
struct Converter<blink::mojom::AILanguageModelCreateOptionsPtr> {
static v8::Local<v8::Value> ToV8(
v8::Isolate* isolate,
const blink::mojom::AILanguageModelCreateOptionsPtr& val) {
if (val.is_null() ||
(val->sampling_params.is_null() && !val->expected_inputs.has_value() &&
!val->expected_outputs.has_value() && val->initial_prompts.empty())) {
return v8::Undefined(isolate);
}
auto dict = gin::Dictionary::CreateEmpty(isolate);
if (val->expected_inputs.has_value() && !val->expected_inputs->empty()) {
dict.Set("expectedInputs", val->expected_inputs.value());
}
if (val->expected_outputs.has_value() && !val->expected_outputs->empty()) {
dict.Set("expectedOutputs", val->expected_outputs.value());
}
if (!val->initial_prompts.empty()) {
dict.Set("initialPrompts", val->initial_prompts);
}
return ConvertToV8(isolate, dict);
}
};
} // namespace gin
namespace electron {
UtilityAIManager::UtilityAIManager(std::optional<int32_t> web_contents_id,
const url::Origin& security_origin)
: web_contents_id_(web_contents_id), security_origin_(security_origin) {
create_model_client_set_.set_disconnect_with_reason_handler(
base::BindRepeating(
&UtilityAIManager::OnCreateLanguageModelClientDisconnect,
weak_ptr_factory_.GetWeakPtr()));
}
UtilityAIManager::~UtilityAIManager() {
// Trigger the abort signal for any in-progress CreateLanguageModel calls
for (auto it = create_model_client_set_.begin();
it != create_model_client_set_.end(); ++it) {
OnCreateLanguageModelClientDisconnect(it.id(), 0, std::string());
}
}
void UtilityAIManager::OnCreateLanguageModelClientDisconnect(
mojo::RemoteSetElementId id,
uint32_t custom_reason,
const std::string& description) {
auto it = abort_controllers_.find(id);
if (it != abort_controllers_.end()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
if (description.empty()) {
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort");
} else {
gin_helper::CallMethod(isolate, it->second.Get(isolate), "abort",
description);
}
abort_controllers_.erase(it);
}
}
v8::Global<v8::Object>& UtilityAIManager::GetLanguageModelClass() {
if (language_model_class_.IsEmpty()) {
auto& handler = electron::api::local_ai_handler::GetPromptAPIHandler();
if (handler.has_value()) {
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
auto details = gin_helper::Dictionary::CreateEmpty(isolate);
if (web_contents_id_.has_value()) {
details.Set("webContentsId", web_contents_id_.value());
} else {
details.Set("webContentsId", nullptr);
}
details.Set("securityOrigin", security_origin_.Serialize());
v8::Local<v8::Value> val = handler->Run(details);
if (val->IsPromise()) {
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate, "Cannot return a promise from the handler"));
node::errors::TriggerUncaughtException(isolate, err, {});
return language_model_class_;
}
if (!val->IsObject() ||
!val->ToObject(isolate->GetCurrentContext())
.ToLocalChecked()
->IsConstructor() ||
!UtilityAILanguageModel::IsLanguageModelClass(isolate, val)) {
auto err = v8::Exception::TypeError(
gin::StringToV8(isolate, "Must provide a constructible class"));
node::errors::TriggerUncaughtException(isolate, err, {});
return language_model_class_;
} else {
language_model_class_.Reset(
isolate,
val->ToObject(isolate->GetCurrentContext()).ToLocalChecked());
}
}
}
return language_model_class_;
}
void UtilityAIManager::SendCreateLanguageModelError(
mojo::RemoteSetElementId client_id,
blink::mojom::AIManagerCreateClientError error) {
abort_controllers_.erase(client_id);
blink::mojom::AIManagerCreateLanguageModelClient* client =
create_model_client_set_.Get(client_id);
if (!client) {
return;
}
client->OnError(error, /*quota_error_info=*/nullptr);
}
void UtilityAIManager::HandleLanguageModelResult(
v8::Isolate* isolate,
v8::Local<v8::Object> language_model,
mojo::RemoteSetElementId client_id,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
abort_controllers_.erase(client_id);
gin_helper::Dictionary dict;
uint64_t context_usage = 0;
uint64_t context_quota = 0;
if (!ConvertFromV8(isolate, language_model, &dict) ||
!dict.Get("contextUsage", &context_usage) ||
!dict.Get("contextWindow", &context_quota)) {
SendCreateLanguageModelError(
client_id,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
return;
}
// TODO - How can the implementation specify the supported prompt types? For
// now, assume all types are supported if the handler returns a valid object
base::flat_set<blink::mojom::AILanguageModelPromptType> enabled_input_types;
if (options->expected_inputs.has_value()) {
for (const auto& expected_input : options->expected_inputs.value()) {
enabled_input_types.insert(expected_input->type);
}
}
blink::mojom::AIManagerCreateLanguageModelClient* client =
create_model_client_set_.Get(client_id);
if (!client) {
return;
}
mojo::PendingRemote<blink::mojom::AILanguageModel> language_model_remote;
language_model_receivers_.Add(
std::make_unique<UtilityAILanguageModel>(language_model,
weak_ptr_factory_.GetWeakPtr()),
language_model_remote.InitWithNewPipeAndPassReceiver());
client->OnResult(
std::move(language_model_remote),
blink::mojom::AILanguageModelInstanceInfo::New(
context_quota, context_usage,
blink::mojom::AILanguageModelSamplingParams::New(),
std::vector<blink::mojom::AILanguageModelPromptType>(
enabled_input_types.begin(), enabled_input_types.end()),
/*audio_sample_rate_hz=*/std::nullopt,
/*audio_channel_count=*/std::nullopt));
}
void UtilityAIManager::CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) {
v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
blink::mojom::ModelAvailabilityCheckResult availability =
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
if (language_model_class.IsEmpty()) {
std::move(callback).Run(availability);
} else {
// If a handler is set, we can create a language model.
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
v8::Local<v8::Value> val = gin_helper::CallMethod(
isolate, language_model_class.Get(isolate), "availability", options);
auto RunCallback = [](v8::Isolate* isolate,
CanCreateLanguageModelCallback callback,
v8::Local<v8::Value> result) {
blink::mojom::ModelAvailabilityCheckResult availability =
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown;
if (result->IsString() &&
gin::ConvertFromV8(isolate, result, &availability)) {
std::move(callback).Run(availability);
} else {
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate, "Invalid return value from LanguageModel.availability()"));
node::errors::TriggerUncaughtException(isolate, err, {});
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
};
if (val->IsPromise()) {
auto promise = val.As<v8::Promise>();
auto split_callback = base::SplitOnceCallback(std::move(callback));
auto then_cb =
base::BindOnce(RunCallback, isolate, std::move(split_callback.first));
auto catch_cb = base::BindOnce(
[](CanCreateLanguageModelCallback callback,
v8::Local<v8::Value> result) {
std::move(callback).Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnknown);
},
std::move(split_callback.second));
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
} else {
// The method is supposed to return a promise, but for
// convenience allow developers to return a value directly
RunCallback(isolate, std::move(callback), val);
}
}
}
void UtilityAIManager::CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
v8::Global<v8::Object>& language_model_class = GetLanguageModelClass();
// Can't create language model if there's no language model class
if (language_model_class.IsEmpty()) {
mojo::Remote<blink::mojom::AIManagerCreateLanguageModelClient>
client_remote(std::move(client));
client_remote->OnError(
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession,
/*quota_error_info=*/nullptr);
return;
}
v8::Isolate* isolate = JavascriptEnvironment::GetIsolate();
v8::HandleScope scope{isolate};
gin_helper::Dictionary options_dict{
isolate, gin::ConvertToV8(isolate, options).As<v8::Object>()};
CreateLanguageModelInternal(isolate, std::move(client),
language_model_class.Get(isolate), "create",
std::move(options_dict), std::move(options));
}
void UtilityAIManager::CreateLanguageModelInternal(
v8::Isolate* isolate,
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
v8::Local<v8::Object> target,
std::string_view method_name,
gin_helper::Dictionary options_dict,
blink::mojom::AILanguageModelCreateOptionsPtr options) {
DCHECK(method_name == "create" || method_name == "clone");
std::string error_source = "LanguageModel." + std::string(method_name) + "()";
mojo::RemoteSetElementId client_id =
create_model_client_set_.Add(std::move(client));
// Store the abort controller so the disconnect handler can abort it.
v8::Local<v8::Object> abort_controller = util::CreateAbortController(isolate);
abort_controllers_.emplace(client_id,
v8::Global<v8::Object>(isolate, abort_controller));
options_dict.Set("signal", abort_controller
->Get(isolate->GetCurrentContext(),
gin::StringToV8(isolate, "signal"))
.ToLocalChecked());
v8::Local<v8::Value> val =
gin_helper::CallMethod(isolate, target, method_name.data(), options_dict);
if (val->IsPromise()) {
auto promise = val.As<v8::Promise>();
auto then_cb = base::BindOnce(
[](base::WeakPtr<UtilityAIManager> weak_ptr, v8::Isolate* isolate,
mojo::RemoteSetElementId client_id,
blink::mojom::AILanguageModelCreateOptionsPtr options,
std::string error_source, v8::Local<v8::Value> result) {
if (weak_ptr) {
if (result->IsObject() &&
UtilityAILanguageModel::IsLanguageModel(isolate, result)) {
weak_ptr->HandleLanguageModelResult(
isolate, result.As<v8::Object>(), client_id,
std::move(options));
} else {
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate, "Invalid return value from " + error_source));
node::errors::TriggerUncaughtException(isolate, err, {});
weak_ptr->SendCreateLanguageModelError(
client_id, blink::mojom::AIManagerCreateClientError::
kUnableToCreateSession);
}
}
},
weak_ptr_factory_.GetWeakPtr(), isolate, client_id, std::move(options),
std::string(error_source));
auto catch_cb = base::BindOnce(
[](base::WeakPtr<UtilityAIManager> weak_ptr,
mojo::RemoteSetElementId client_id, v8::Local<v8::Value> result) {
if (weak_ptr) {
weak_ptr->SendCreateLanguageModelError(
client_id, blink::mojom::AIManagerCreateClientError::
kUnableToCreateSession);
}
},
weak_ptr_factory_.GetWeakPtr(), client_id);
std::ignore = promise->Then(
isolate->GetCurrentContext(),
gin::ConvertToV8(isolate, std::move(then_cb)).As<v8::Function>(),
gin::ConvertToV8(isolate, std::move(catch_cb)).As<v8::Function>());
} else if (val->IsObject() &&
UtilityAILanguageModel::IsLanguageModel(isolate, val)) {
// The method is supposed to return a promise, but for
// convenience allow developers to return a value directly
HandleLanguageModelResult(isolate, val.As<v8::Object>(), client_id,
std::move(options));
} else {
auto err = v8::Exception::TypeError(gin::StringToV8(
isolate, "Invalid return value from " + std::string(error_source)));
node::errors::TriggerUncaughtException(isolate, err, {});
SendCreateLanguageModelError(
client_id,
blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
}
}
void UtilityAIManager::CanCreateSummarizer(
blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void UtilityAIManager::CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void UtilityAIManager::GetLanguageModelParams(
GetLanguageModelParamsCallback callback) {
NOTIMPLEMENTED();
}
void UtilityAIManager::CanCreateWriter(
blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void UtilityAIManager::CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void UtilityAIManager::CanCreateRewriter(
blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void UtilityAIManager::CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void UtilityAIManager::CanCreateProofreader(
blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) {
std::move(callback).Run(
blink::mojom::ModelAvailabilityCheckResult::kUnavailableUnknown);
}
void UtilityAIManager::CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient> client,
blink::mojom::AIProofreaderCreateOptionsPtr options) {
NOTIMPLEMENTED();
}
void UtilityAIManager::AddModelDownloadProgressObserver(
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
observer_remote) {
NOTIMPLEMENTED();
}
} // namespace electron

View File

@@ -0,0 +1,125 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
#define ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_
#include <optional>
#include <string>
#include <string_view>
#include "base/memory/weak_ptr.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "mojo/public/cpp/bindings/unique_receiver_set.h"
#include "services/on_device_model/public/mojom/download_observer.mojom-forward.h"
#include "shell/common/gin_helper/dictionary.h"
#include "third_party/abseil-cpp/absl/container/flat_hash_map.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_proofreader.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom-forward.h"
#include "url/origin.h"
#include "v8/include/v8.h"
namespace electron {
class UtilityAILanguageModel;
// The utility-side implementation of `blink::mojom::AIManager`.
class UtilityAIManager : public blink::mojom::AIManager {
public:
UtilityAIManager(std::optional<int32_t> web_contents_id,
const url::Origin& security_origin);
UtilityAIManager(const UtilityAIManager&) = delete;
UtilityAIManager& operator=(const UtilityAIManager&) = delete;
~UtilityAIManager() override;
private:
friend class UtilityAILanguageModel;
void OnCreateLanguageModelClientDisconnect(mojo::RemoteSetElementId id,
uint32_t custom_reason,
const std::string& description);
[[nodiscard]] v8::Global<v8::Object>& GetLanguageModelClass();
void SendCreateLanguageModelError(
mojo::RemoteSetElementId client_id,
blink::mojom::AIManagerCreateClientError error);
void CreateLanguageModelInternal(
v8::Isolate* isolate,
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
v8::Local<v8::Object> target,
std::string_view method_name,
gin_helper::Dictionary options_dict,
blink::mojom::AILanguageModelCreateOptionsPtr options);
void HandleLanguageModelResult(
v8::Isolate* isolate,
v8::Local<v8::Object> language_model,
mojo::RemoteSetElementId client_id,
blink::mojom::AILanguageModelCreateOptionsPtr options);
// `blink::mojom::AIManager` implementation.
void CanCreateLanguageModel(
blink::mojom::AILanguageModelCreateOptionsPtr options,
CanCreateLanguageModelCallback callback) override;
void CreateLanguageModel(
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
client,
blink::mojom::AILanguageModelCreateOptionsPtr options) override;
void CanCreateSummarizer(blink::mojom::AISummarizerCreateOptionsPtr options,
CanCreateSummarizerCallback callback) override;
void CreateSummarizer(
mojo::PendingRemote<blink::mojom::AIManagerCreateSummarizerClient> client,
blink::mojom::AISummarizerCreateOptionsPtr options) override;
void GetLanguageModelParams(GetLanguageModelParamsCallback callback) override;
void CanCreateWriter(blink::mojom::AIWriterCreateOptionsPtr options,
CanCreateWriterCallback callback) override;
void CreateWriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateWriterClient> client,
blink::mojom::AIWriterCreateOptionsPtr options) override;
void CanCreateRewriter(blink::mojom::AIRewriterCreateOptionsPtr options,
CanCreateRewriterCallback callback) override;
void CreateRewriter(
mojo::PendingRemote<blink::mojom::AIManagerCreateRewriterClient> client,
blink::mojom::AIRewriterCreateOptionsPtr options) override;
void CanCreateProofreader(blink::mojom::AIProofreaderCreateOptionsPtr options,
CanCreateProofreaderCallback callback) override;
void CreateProofreader(
mojo::PendingRemote<blink::mojom::AIManagerCreateProofreaderClient>
client,
blink::mojom::AIProofreaderCreateOptionsPtr options) override;
void AddModelDownloadProgressObserver(
mojo::PendingRemote<on_device_model::mojom::DownloadObserver>
observer_remote) override;
std::optional<int32_t> web_contents_id_;
url::Origin security_origin_;
v8::Global<v8::Object> language_model_class_;
mojo::RemoteSet<blink::mojom::AIManagerCreateLanguageModelClient>
create_model_client_set_;
// Maps each in-progress CreateLanguageModel client to its AbortController
// so we can abort the JS-side operation if the client disconnects.
absl::flat_hash_map<mojo::RemoteSetElementId, v8::Global<v8::Object>>
abort_controllers_;
// Owns all created UtilityAILanguageModel instances
mojo::UniqueReceiverSet<blink::mojom::AILanguageModel>
language_model_receivers_;
base::WeakPtrFactory<UtilityAIManager> weak_ptr_factory_{this};
};
} // namespace electron
#endif // ELECTRON_SHELL_UTILITY_AI_UTILITY_AI_MANAGER_H_

View File

@@ -0,0 +1,53 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#include "shell/utility/api/electron_api_local_ai_handler.h"
#include <optional>
#include "base/no_destructor.h"
#include "shell/common/gin_converters/callback_converter.h"
#include "shell/common/gin_helper/dictionary.h"
#include "shell/common/node_includes.h"
#include "v8/include/v8.h"
namespace electron::api::local_ai_handler {
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> val) {
PromptAPIHandler handler;
if (!(val->IsNull() || gin::ConvertFromV8(isolate, val, &handler))) {
isolate->ThrowException(v8::Exception::TypeError(
gin::StringToV8(isolate, "Must pass null or function")));
return;
}
if (val->IsNull()) {
GetPromptAPIHandler() = std::nullopt;
} else {
GetPromptAPIHandler() = handler;
}
}
std::optional<PromptAPIHandler>& GetPromptAPIHandler() {
static base::NoDestructor<std::optional<PromptAPIHandler>> prompt_api_handler;
return *prompt_api_handler;
}
} // namespace electron::api::local_ai_handler
namespace {
void Initialize(v8::Local<v8::Object> exports,
v8::Local<v8::Value> unused,
v8::Local<v8::Context> context,
void* priv) {
v8::Isolate* const isolate = v8::Isolate::GetCurrent();
gin_helper::Dictionary dict{isolate, exports};
dict.SetMethod("setPromptAPIHandler",
&electron::api::local_ai_handler::SetPromptAPIHandler);
}
} // namespace
NODE_LINKED_BINDING_CONTEXT_AWARE(electron_utility_local_ai_handler, Initialize)

View File

@@ -0,0 +1,28 @@
// Copyright (c) 2025 Microsoft, Inc.
// Use of this source code is governed by the MIT license that can be
// found in the LICENSE file.
#ifndef ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_
#define ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_
#include <optional>
#include "base/functional/callback_forward.h"
#include "v8/include/v8-forward.h"
namespace gin_helper {
class Dictionary;
}
namespace electron::api::local_ai_handler {
using PromptAPIHandler =
base::RepeatingCallback<v8::Local<v8::Value>(gin_helper::Dictionary)>;
void SetPromptAPIHandler(v8::Isolate* isolate, v8::Local<v8::Value> value);
[[nodiscard]] std::optional<PromptAPIHandler>& GetPromptAPIHandler();
} // namespace electron::api::local_ai_handler
#endif // ELECTRON_SHELL_UTILITY_API_ELECTRON_LOCAL_AI_HANDLER_H_

View File

@@ -17,7 +17,7 @@ import * as nodeUrl from 'node:url';
import { emittedUntil, emittedNTimes } from './lib/events-helpers';
import { randomString } from './lib/net-helpers';
import { HexColors, hasCapturableScreen, ScreenCapture } from './lib/screen-helpers';
import { ifit, ifdescribe, defer, listen, waitUntil, isWayland } from './lib/spec-helpers';
import { ifit, ifdescribe, defer, listen, waitUntil } from './lib/spec-helpers';
import { closeWindow, closeAllWindows } from './lib/window-helpers';
const fixtures = path.resolve(__dirname, 'fixtures');
@@ -1204,45 +1204,7 @@ describe('BrowserWindow module', () => {
});
}
describe('visibility', () => {
let w: BrowserWindow;
beforeEach(() => {
w = new BrowserWindow({ show: false });
});
afterEach(async () => {
await closeWindow(w);
w = null as unknown as BrowserWindow;
});
describe('BrowserWindow.show()', () => {
it('should make the window visible', async () => {
const show = once(w, 'show');
w.show();
await show;
expect(w.isVisible()).to.equal(true);
});
});
describe('BrowserWindow.hide()', () => {
it('should make the window not visible', () => {
w.show();
w.hide();
expect(w.isVisible()).to.equal(false);
});
it('emits when window is hidden', async () => {
const shown = once(w, 'show');
w.show();
await shown;
const hidden = once(w, 'hide');
w.hide();
await hidden;
expect(w.isVisible()).to.equal(false);
});
});
});
// Wayland does not allow focus and z-order to be controlled without user input
ifdescribe(!isWayland)('focus, blur, and z-order', () => {
describe('focus and visibility', () => {
let w: BrowserWindow;
beforeEach(() => {
w = new BrowserWindow({ show: false });
@@ -1259,12 +1221,18 @@ describe('BrowserWindow module', () => {
await p;
expect(w.isFocused()).to.equal(true);
});
it('emits focus event and makes the window visible', async () => {
it('should make the window visible', async () => {
const p = once(w, 'focus');
w.show();
await p;
expect(w.isVisible()).to.equal(true);
});
it('emits when window is shown', async () => {
const show = once(w, 'show');
w.show();
await show;
expect(w.isVisible()).to.equal(true);
});
});
describe('BrowserWindow.hide()', () => {
@@ -1272,6 +1240,20 @@ describe('BrowserWindow module', () => {
w.hide();
expect(w.isFocused()).to.equal(false);
});
it('should make the window not visible', () => {
w.show();
w.hide();
expect(w.isVisible()).to.equal(false);
});
it('emits when window is hidden', async () => {
const shown = once(w, 'show');
w.show();
await shown;
const hidden = once(w, 'hide');
w.hide();
await hidden;
expect(w.isVisible()).to.equal(false);
});
});
describe('BrowserWindow.minimize()', () => {
@@ -1644,20 +1626,6 @@ describe('BrowserWindow module', () => {
await closeWindow(w2, { assertNotWindows: false });
});
});
describe('window.webContents.focus()', () => {
afterEach(closeAllWindows);
it('focuses window', async () => {
const w1 = new BrowserWindow({ x: 100, y: 300, width: 300, height: 200 });
w1.loadURL('about:blank');
const w2 = new BrowserWindow({ x: 300, y: 300, width: 300, height: 200 });
w2.loadURL('about:blank');
const w1Focused = once(w1, 'focus');
w1.webContents.focus();
await w1Focused;
expect(w1.webContents.isFocused()).to.be.true('focuses window');
});
});
});
describe('sizing', () => {
@@ -1846,8 +1814,7 @@ describe('BrowserWindow module', () => {
});
});
// Windows cannot be programmatically moved on Wayland
ifdescribe(!isWayland)('BrowserWindow.setContentBounds(bounds)', () => {
describe('BrowserWindow.setContentBounds(bounds)', () => {
it('sets the content size and position', async () => {
const bounds = { x: 10, y: 10, width: 250, height: 250 };
const resize = once(w, 'resize');
@@ -3318,19 +3285,6 @@ describe('BrowserWindow module', () => {
});
});
// On Wayland, hidden windows may not have mapped surfaces or finalized geometry
// until shown. Tests that depend on real geometry or frame events may need
// to show the window first.
const showWindowForWayland = async (w: BrowserWindow) => {
if (!isWayland || w.isVisible()) {
return;
}
const shown = once(w, 'show');
w.show();
await shown;
};
describe('"titleBarStyle" option', () => {
const testWindowsOverlay = async (style: any) => {
const w = new BrowserWindow({
@@ -3350,10 +3304,8 @@ describe('BrowserWindow module', () => {
} else {
const overlayReady = once(ipcMain, 'geometrychange');
await w.loadFile(overlayHTML);
await showWindowForWayland(w);
await overlayReady;
}
const overlayEnabled = await w.webContents.executeJavaScript('navigator.windowControlsOverlay.visible');
expect(overlayEnabled).to.be.true('overlayEnabled');
const overlayRect = await w.webContents.executeJavaScript('getJSOverlayProperties()');
@@ -3466,7 +3418,6 @@ describe('BrowserWindow module', () => {
} else {
const overlayReady = once(ipcMain, 'geometrychange');
await w.loadFile(overlayHTML);
await showWindowForWayland(w);
await overlayReady;
}
@@ -3540,7 +3491,6 @@ describe('BrowserWindow module', () => {
const overlayHTML = path.join(__dirname, 'fixtures', 'pages', 'overlay.html');
const overlayReady = once(ipcMain, 'geometrychange');
await w.loadFile(overlayHTML);
await showWindowForWayland(w);
if (firstRun) {
await overlayReady;
}
@@ -4821,9 +4771,7 @@ describe('BrowserWindow module', () => {
const w = new BrowserWindow({ show: false });
let called = false;
w.loadFile(path.join(fixtures, 'api', 'frame-subscriber.html'));
w.webContents.on('dom-ready', async () => {
await showWindowForWayland(w);
w.webContents.on('dom-ready', () => {
w.webContents.beginFrameSubscription(function () {
// This callback might be called twice.
if (called) return;
@@ -4843,9 +4791,7 @@ describe('BrowserWindow module', () => {
const w = new BrowserWindow({ show: false });
let called = false;
w.loadFile(path.join(fixtures, 'api', 'frame-subscriber.html'));
w.webContents.on('dom-ready', async () => {
await showWindowForWayland(w);
w.webContents.on('dom-ready', () => {
w.webContents.beginFrameSubscription(function (data) {
// This callback might be called twice.
if (called) return;
@@ -4869,9 +4815,7 @@ describe('BrowserWindow module', () => {
let called = false;
let gotInitialFullSizeFrame = false;
const [contentWidth, contentHeight] = w.getContentSize();
w.webContents.on('did-finish-load', async () => {
await showWindowForWayland(w);
w.webContents.on('did-finish-load', () => {
w.webContents.beginFrameSubscription(true, (image, rect) => {
if (image.isEmpty()) {
// Chromium sometimes sends a 0x0 frame at the beginning of the
@@ -5452,57 +5396,55 @@ describe('BrowserWindow module', () => {
await createTwo();
});
ifdescribe(process.platform !== 'darwin' && !isWayland)('disabling parent windows', () => {
it('can disable and enable a window', () => {
const w = new BrowserWindow({ show: false });
w.setEnabled(false);
expect(w.isEnabled()).to.be.false('w.isEnabled()');
w.setEnabled(true);
expect(w.isEnabled()).to.be.true('!w.isEnabled()');
});
ifit(process.platform !== 'darwin')('can disable and enable a window', () => {
const w = new BrowserWindow({ show: false });
w.setEnabled(false);
expect(w.isEnabled()).to.be.false('w.isEnabled()');
w.setEnabled(true);
expect(w.isEnabled()).to.be.true('!w.isEnabled()');
});
it('disables parent window', () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
expect(w.isEnabled()).to.be.true('w.isEnabled');
c.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
});
ifit(process.platform !== 'darwin')('disables parent window', () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
expect(w.isEnabled()).to.be.true('w.isEnabled');
c.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
});
it('re-enables an enabled parent window when closed', async () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const closed = once(c, 'closed');
c.show();
c.close();
await closed;
expect(w.isEnabled()).to.be.true('w.isEnabled');
});
ifit(process.platform !== 'darwin')('re-enables an enabled parent window when closed', async () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const closed = once(c, 'closed');
c.show();
c.close();
await closed;
expect(w.isEnabled()).to.be.true('w.isEnabled');
});
it('does not re-enable a disabled parent window when closed', async () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const closed = once(c, 'closed');
w.setEnabled(false);
c.show();
c.close();
await closed;
expect(w.isEnabled()).to.be.false('w.isEnabled');
});
ifit(process.platform !== 'darwin')('does not re-enable a disabled parent window when closed', async () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const closed = once(c, 'closed');
w.setEnabled(false);
c.show();
c.close();
await closed;
expect(w.isEnabled()).to.be.false('w.isEnabled');
});
it('disables parent window recursively', () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const c2 = new BrowserWindow({ show: false, parent: w, modal: true });
c.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c2.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c.destroy();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c2.destroy();
expect(w.isEnabled()).to.be.true('w.isEnabled');
});
ifit(process.platform !== 'darwin')('disables parent window recursively', () => {
const w = new BrowserWindow({ show: false });
const c = new BrowserWindow({ show: false, parent: w, modal: true });
const c2 = new BrowserWindow({ show: false, parent: w, modal: true });
c.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c2.show();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c.destroy();
expect(w.isEnabled()).to.be.false('w.isEnabled');
c2.destroy();
expect(w.isEnabled()).to.be.true('w.isEnabled');
});
});
});
@@ -5742,7 +5684,7 @@ describe('BrowserWindow module', () => {
});
});
ifdescribe(process.platform !== 'win32' && !isWayland)('visibleOnAllWorkspaces state', () => {
ifdescribe(process.platform !== 'win32')('visibleOnAllWorkspaces state', () => {
describe('with properties', () => {
it('can be changed', () => {
const w = new BrowserWindow({ show: false });
@@ -6893,6 +6835,20 @@ describe('BrowserWindow module', () => {
});
});
describe('window.webContents.focus()', () => {
afterEach(closeAllWindows);
it('focuses window', async () => {
const w1 = new BrowserWindow({ x: 100, y: 300, width: 300, height: 200 });
w1.loadURL('about:blank');
const w2 = new BrowserWindow({ x: 300, y: 300, width: 300, height: 200 });
w2.loadURL('about:blank');
const w1Focused = once(w1, 'focus');
w1.webContents.focus();
await w1Focused;
expect(w1.webContents.isFocused()).to.be.true('focuses window');
});
});
describe('offscreen rendering', () => {
let w: BrowserWindow;
beforeEach(function () {

View File

@@ -0,0 +1,995 @@
import { BrowserWindow, session, utilityProcess } from 'electron/main';
import { expect } from 'chai';
import { on, once } from 'node:events';
import * as path from 'node:path';
import { ifdescribe } from './lib/spec-helpers';
import { closeAllWindows } from './lib/window-helpers';
const features = process._linkedBinding('electron_common_features');
function getFixturePath (fixtureName: string) {
return path.join(path.resolve(__dirname, 'fixtures', 'api', 'local-ai-handler'), fixtureName);
}
// Await fn and listen for a message of the given type, returning the message once received
// Used to listen for a message triggered as a side effect of fn, where we don't care about the result of fn
async function listenForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<void> | void) {
const messages = on(aiHandler, 'message');
await fn();
for await (const [message] of messages) {
if (message.type === messageType) {
return message;
}
}
return null;
}
// Call fn and await a message of the given type, returning the message and the promise returned by fn
// Used to listen for a message triggered as a side effect of fn, where we do care about the result of fn
async function waitForMessage (aiHandler: Electron.UtilityProcess, messageType: string, fn: () => Promise<unknown>) {
let promise: Promise<unknown>;
await listenForMessage(aiHandler, messageType, () => {
promise = fn();
});
return { promise: promise! };
}
ifdescribe(features.isPromptAPIEnabled())('localAIHandler module', () => {
const fixtures = path.resolve(__dirname, 'fixtures');
let w: Electron.BrowserWindow;
async function forkAndRegisterHandler (fixtureName: string) {
const aiHandler = utilityProcess.fork(getFixturePath(fixtureName));
await once(aiHandler, 'spawn');
w.webContents.session.registerLocalAIHandler(aiHandler);
return aiHandler;
}
async function sendControllableMessage (aiHandler: Electron.UtilityProcess, message: unknown) {
const ackEvent = once(aiHandler, 'message');
aiHandler.postMessage(message);
await ackEvent;
}
beforeEach(async () => {
w = new BrowserWindow({
show: false,
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI,AIPromptAPIMultimodalInput'
}
});
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
});
afterEach(() => {
w.webContents.session.registerLocalAIHandler(null);
closeAllWindows();
});
describe('LanguageModel.availability()', () => {
it('is unavailable if invalid value returned', async () => {
await forkAndRegisterHandler('buggy-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('returns "available" when handler reports available', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
});
it('returns "downloadable" when handler reports downloadable', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloadable' });
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloadable');
});
it('returns "downloading" when handler reports downloading', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('downloading');
});
it('returns "unavailable" when handler reports unavailable', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'unavailable' });
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('returns "unavailable" when the availability() promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'reject' });
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('returns "unavailable" if the utility process dies', async () => {
const aiHandler = await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
aiHandler.kill();
await once(aiHandler, 'exit');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('returns "unavailable" if not registered', async () => {
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('returns "unavailable" if registered but utility process has not set handler', async () => {
await forkAndRegisterHandler('no-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('passes options to the availability() call', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-availability', value: 'downloading' });
const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }] };
const message = once(aiHandler, 'message');
await w.webContents.executeJavaScript(`LanguageModel.availability(${JSON.stringify(options)})`);
const [receivedMessage] = await message;
expect(receivedMessage.options).to.deep.equal(options);
expect(receivedMessage.type).to.equal('availability-called');
});
});
describe('LanguageModel.create()', () => {
async function expectRejectedWithError (message: string | RegExp, options?: Object) {
// Unwrap the error message because NotAllowedError won't serialize
if (options) {
await expect(w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)}).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
} else {
await expect(w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(message);
}
}
it('rejects if invalid value returned', async () => {
await forkAndRegisterHandler('buggy-language-model.js');
await expectRejectedWithError(/unable to create/);
});
it('rejects when no handler is registered', async () => {
await expectRejectedWithError(/unable to create/);
});
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'reject' });
await expectRejectedWithError(/unable to create/);
});
it('rejects if the utility process dies during creation', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
});
it('rejects if the handler gets unregistered during creation', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'create-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().catch(err => { throw err.message; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/unable to create/);
});
it('creates a LanguageModel instance from a valid handler', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
});
it('passes initialPrompts to create()', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const options = { initialPrompts: [{ role: 'system', content: [{ type: 'text', value: 'You are Electron AI' }] }] };
const message = await listenForMessage(aiHandler, 'create-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
});
expect(message.options).to.have.property('signal');
delete message.options.signal;
expect(message.options).to.deep.equal({ initialPrompts: options.initialPrompts.map(prompt => ({ ...prompt, prefix: false })) });
});
it('passes expectedInputs and expectedOutputs options', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const options = { expectedInputs: [{ type: 'image' }, { type: 'text', languages: ['en', 'fr'] }], expectedOutputs: [{ type: 'text', languages: ['en', 'fr'] }] };
const message = await listenForMessage(aiHandler, 'create-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create(${JSON.stringify(options)})`);
});
expect(message.options).to.have.property('signal');
delete message.options.signal;
expect(message.options).to.deep.equal(options);
});
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-create', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'create-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create({ signal: AbortSignal.timeout(500) }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
it('exposes contextUsage and contextWindow on the created model', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 12345 });
});
});
describe('LanguageModel.prompt()', () => {
async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
// Unwrap the error message because NotAllowedError won't serialize
if (options) {
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)}, ${JSON.stringify(options)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
} else {
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(prompt)})).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
}
}
it('rejects when handler returns an invalid value', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
await expectRejectedWithError(/error occurred/, 'Test prompt');
});
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
await expectRejectedWithError(/error occurred/, 'Test prompt');
});
it('rejects after the model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the utility process dies during prompt', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the handler gets unregistered during prompt', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Test")).catch(err => { throw err.message; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('returns a string response from the handler', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
});
it('returns a ReadableStream response from the handler', async () => {
await forkAndRegisterHandler('streaming-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('Hello World');
});
it('passes string input to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt(\'hello world\'))');
});
expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
});
it('passes LanguageModelMessage[] input to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt(${JSON.stringify(input)}))`);
});
expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
});
it('passes responseConstraint option to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create().then(model => model.prompt('test', { responseConstraint: ${JSON.stringify(responseConstraint)} }))`);
});
expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
});
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
it('updates contextUsage after a prompt', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.prompt("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 10 });
});
});
describe('LanguageModel.promptStreaming()', () => {
const collectStream = 'async (stream) => { const reader = stream.getReader(); let r = ""; while (true) { const { done, value } = await reader.read(); if (done) return r; r += value; } }';
async function expectRejectedWithError (message: string | RegExp, prompt: string, options?: Object) {
const collectStreamFn = collectStream;
// Unwrap the error message because NotAllowedError won't serialize
if (options) {
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)}, ${JSON.stringify(options)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
} else {
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStreamFn}; return collect(model.promptStreaming(${JSON.stringify(prompt)})); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(message);
}
}
it('rejects when handler returns an invalid value', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 99 });
await expectRejectedWithError(/error occurred/, 'Test prompt');
});
it('rejects when ReadableStream returns an invalid value', async () => {
await forkAndRegisterHandler('buggy-streaming-language-model.js');
await expectRejectedWithError(/has been destroyed/, 'Test prompt');
});
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'reject' });
await expectRejectedWithError(/error occurred/, 'Test prompt');
});
it('rejects after the model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { model.destroy(); const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the utility process dies during prompt', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the handler gets unregistered during prompt', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
return w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Test")); }).catch(err => { throw err.message; })`);
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('returns a string response from the handler', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('foobar');
});
it('returns a ReadableStream response from the handler', async () => {
await forkAndRegisterHandler('streaming-language-model.js');
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("Hi")); })`)).to.equal('Hello World');
});
it('passes string input to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('hello world')); })`);
});
expect(message.input).to.deep.equal([{ role: 'user', content: [{ type: 'text', value: 'hello world' }], prefix: false }]);
});
it('passes LanguageModelMessage[] input to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const input = [{ role: 'user', content: [{ type: 'text', value: 'hello' }] }, { role: 'assistant', content: [{ type: 'text', value: 'hi' }] }];
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming(${JSON.stringify(input)})); })`);
});
expect(message.input).to.deep.equal(input.map(msg => ({ ...msg, prefix: false })));
});
it('passes responseConstraint option to the handler', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const responseConstraint = { type: 'object', properties: { name: { type: 'string' } } };
const message = await listenForMessage(aiHandler, 'prompt-called', async () => {
await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming('test', { responseConstraint: ${JSON.stringify(responseConstraint)} })); })`);
});
expect(message.options.responseConstraint).to.deep.equal(responseConstraint);
});
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
await expect(w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; return collect(model.promptStreaming("test", { signal: AbortSignal.timeout(500) })); }).catch(err => { throw err.message; })`)).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
it('updates contextUsage after a prompt', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
expect(await w.webContents.executeJavaScript(`LanguageModel.create().then(async (model) => { const collect = ${collectStream}; await collect(model.promptStreaming("hello world")); return { contextUsage: model.contextUsage }; })`)).to.deep.equal({ contextUsage: 10 });
});
});
describe('LanguageModel.append()', () => {
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'reject' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/error occurred/);
});
it('rejects after the model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.append("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the utility process dies during append', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the handler gets unregistered during append', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('appends a message without producing a response', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Test")).catch(err => { throw err.message; })')).to.be.undefined();
});
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
it('updates contextUsage after append', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage }))')).to.deep.equal({ contextUsage: 0 });
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(async (model) => { await model.append("hello world"); return { contextUsage: model.contextUsage } })')).to.deep.equal({ contextUsage: 5 });
});
});
describe('LanguageModel.measureContextUsage()', () => {
it('rejects if invalid value returned', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'invalid' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
});
it('rejects when handler promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'reject' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/usage cannot be calculated/);
});
it('rejects after the model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.measureContextUsage("Test") }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the utility process dies during call', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejected();
});
it('rejects if the handler gets unregistered during call', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'measure-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Test")).catch(err => { throw err?.message ?? "Unknown Error"; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejected();
});
it('returns the token count for the given input', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("hello world"))')).to.equal(42);
});
// TODO(dsanders11): Upstream Chromium issue prevents this test from passing as
// there's no Mojo connection to disconnect trip abort signal
it.skip('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-measure-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'measure-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("test", { signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
});
describe('LanguageModel.clone()', () => {
it('rejects when clone() returns a non-LanguageModel value', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'invalid' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects when clone() promise rejects', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'reject' });
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects after the original model has been destroyed', async () => {
await forkAndRegisterHandler('basic-language-model.js');
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.clone(); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
it('rejects if the utility process dies during clone', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
});
aiHandler.kill();
await once(aiHandler, 'exit');
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('rejects if the handler gets unregistered during clone', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const { promise } = await waitForMessage(aiHandler, 'clone-called', () => {
return w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).catch(err => { throw err.message; })');
});
w.webContents.session.registerLocalAIHandler(null);
await expect(promise).to.eventually.be.rejectedWith(/cannot be cloned/);
});
it('returns a new LanguageModel instance', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
});
it('preserves context from the original model', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript(`
LanguageModel.create().then(async (model) => {
await model.prompt("hello");
const cloned = await model.clone();
return { contextUsage: cloned.contextUsage, contextWindow: cloned.contextWindow };
})
`)).to.deep.equal({ contextUsage: 10, contextWindow: 12345 });
});
it('plumbs the abort signal through', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-clone-response', value: 'wait-for-abort' });
const message = await listenForMessage(aiHandler, 'clone-aborted', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone({ signal: AbortSignal.timeout(500) })).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/signal timed out/);
});
expect(message).not.null();
});
});
describe('LanguageModel.destroy()', () => {
it('destroys the model', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
const message = await listenForMessage(aiHandler, 'destroy-called', async () => {
await expect(w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); return model.prompt("Test"); }).catch(err => { throw err.message; })')).to.eventually.be.rejectedWith(/has been destroyed/);
});
expect(message).not.null();
});
it('aborts any in-progress prompt calls', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-prompt-response', value: 'wait-for-abort' });
await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
const { promise } = await waitForMessage(aiHandler, 'prompt-called', () => {
return w.webContents.executeJavaScript('window._model.prompt("Test").catch(err => { throw err.message; })');
});
const message = await listenForMessage(aiHandler, 'prompt-aborted', async () => {
await w.webContents.executeJavaScript('window._model.destroy()');
});
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
expect(message).not.null();
});
it('aborts any in-progress append calls', async () => {
const aiHandler = await forkAndRegisterHandler('controllable-language-model.js');
await sendControllableMessage(aiHandler, { command: 'set-append-response', value: 'wait-for-abort' });
await w.webContents.executeJavaScript('LanguageModel.create().then(model => { window._model = model; })');
const { promise } = await waitForMessage(aiHandler, 'append-called', () => {
return w.webContents.executeJavaScript('window._model.append("Test").catch(err => { throw err.message; })');
});
const message = await listenForMessage(aiHandler, 'append-aborted', async () => {
await w.webContents.executeJavaScript('window._model.destroy()');
});
await w.webContents.executeJavaScript('window._model.destroy()');
await expect(promise).to.eventually.be.rejectedWith(/has been destroyed/);
expect(message).not.null();
});
it('can be called multiple times without error', async () => {
await forkAndRegisterHandler('basic-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => { model.destroy(); model.destroy(); model.destroy(); return true; })')).to.equal(true);
});
});
describe('setPromptAPIHandler()', () => {
it('rejects if handler returns a promise', async () => {
await forkAndRegisterHandler('promise-handler-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('rejects if handler returns a non-class value', async () => {
await forkAndRegisterHandler('non-class-handler-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('rejects if handler returns a class not extending LanguageModel', async () => {
await forkAndRegisterHandler('non-language-model-handler.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
it('receives webContentsId in the details object', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.details).to.have.property('webContentsId', w.webContents.id);
});
it('receives securityOrigin in the details object', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.details).to.have.property('securityOrigin');
expect(message.details.securityOrigin).to.be.a('string').and.not.be.empty();
});
it('is called once per webContentsId and securityOrigin pair', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
const message = await listenForMessage(aiHandler, 'handler-called', async () => {
await w.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message.callCount).to.equal(1);
// Calling availability again should not trigger the handler again
await w.webContents.executeJavaScript('LanguageModel.availability()');
// Create a second window with the same session - should trigger handler again (different webContentsId)
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: w.webContents.session,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
const message2 = await listenForMessage(aiHandler, 'handler-called', async () => {
await w2.webContents.executeJavaScript('LanguageModel.availability()');
});
expect(message2.callCount).to.equal(2);
});
it('can be cleared by calling with null', async () => {
const aiHandler = await forkAndRegisterHandler('handler-details-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Clear the handler inside the utility process
await sendControllableMessage(aiHandler, { command: 'clear-handler' });
// Existing Prompt API bindings should still work until the page is reloaded
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Load a new page to get a fresh Prompt API binding
await w.loadFile(path.join(fixtures, 'api', 'blank.html'));
// Should be unavailable since setPromptAPIHandler(null) was called in the utility process
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
});
});
describe('LanguageModel base class', () => {
it('provides default no-op implementations for all methods', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.append("Hi"))')).to.be.undefined();
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.measureContextUsage("Hi"))')).to.equal(0);
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model.clone()).then(cloned => cloned instanceof LanguageModel)')).to.equal(true);
});
it('can use the base LanguageModel class directly without subclassing', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => model instanceof LanguageModel)')).to.equal(true);
expect(await w.webContents.executeJavaScript('LanguageModel.create().then(model => ({ contextUsage: model.contextUsage, contextWindow: model.contextWindow }))')).to.deep.equal({ contextUsage: 0, contextWindow: 0 });
});
});
describe('session isolation', () => {
it('applies to all windows using the same session', async () => {
await forkAndRegisterHandler('default-language-model.js');
expect(await w.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
const w2 = new BrowserWindow({
show: false,
webPreferences: {
enableBlinkFeatures: 'AIPromptAPI'
}
});
await w2.loadFile(path.join(fixtures, 'api', 'blank.html'));
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
});
it('different sessions can use different handler processes', async () => {
const ses1 = session.fromPartition('ai-isolation-1');
const ses2 = session.fromPartition('ai-isolation-2');
const w1 = new BrowserWindow({
show: false,
webPreferences: {
session: ses1,
enableBlinkFeatures: 'AIPromptAPI'
}
});
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: ses2,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await Promise.all([
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
]);
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler1, 'spawn');
ses1.registerLocalAIHandler(aiHandler1);
const aiHandler2 = utilityProcess.fork(getFixturePath('default-language-model.js'));
await once(aiHandler2, 'spawn');
ses2.registerLocalAIHandler(aiHandler2);
try {
// basic-language-model returns 'foobar'
expect(await w1.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('foobar');
// default-language-model returns ''
expect(await w2.webContents.executeJavaScript('LanguageModel.create().then(model => model.prompt("Hi"))')).to.equal('');
} finally {
ses1.registerLocalAIHandler(null);
ses2.registerLocalAIHandler(null);
}
});
it('clearing one session handler does not affect another', async () => {
const ses1 = session.fromPartition('ai-isolation-clear-1');
const ses2 = session.fromPartition('ai-isolation-clear-2');
const w1 = new BrowserWindow({
show: false,
webPreferences: {
session: ses1,
enableBlinkFeatures: 'AIPromptAPI'
}
});
const w2 = new BrowserWindow({
show: false,
webPreferences: {
session: ses2,
enableBlinkFeatures: 'AIPromptAPI'
}
});
await Promise.all([
w1.loadFile(path.join(fixtures, 'api', 'blank.html')),
w2.loadFile(path.join(fixtures, 'api', 'blank.html'))
]);
const aiHandler1 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler1, 'spawn');
ses1.registerLocalAIHandler(aiHandler1);
const aiHandler2 = utilityProcess.fork(getFixturePath('basic-language-model.js'));
await once(aiHandler2, 'spawn');
ses2.registerLocalAIHandler(aiHandler2);
try {
// Both should be available
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
// Clear handler for session 1
ses1.registerLocalAIHandler(null);
// Session 1 should be unavailable, session 2 should still be available
expect(await w1.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('unavailable');
expect(await w2.webContents.executeJavaScript('LanguageModel.availability()')).to.equal('available');
} finally {
ses1.registerLocalAIHandler(null);
ses2.registerLocalAIHandler(null);
}
});
});
});

Some files were not shown because too many files have changed in this diff Show More