Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-14 08:45:12 -05:00)
Compare commits
1 commit
fix/copilo
...
feat/opena
| Author | SHA1 | Date |
|---|---|---|
|  | 889b4e4152 |  |
@@ -32,6 +32,14 @@ from backend.data.model import (
 from backend.integrations.providers import ProviderName
 from backend.util import json
 from backend.util.logging import TruncatedLogger
+from backend.util.openai_responses import (
+    convert_tools_to_responses_format,
+    extract_responses_content,
+    extract_responses_reasoning,
+    extract_responses_tool_calls,
+    extract_usage,
+    requires_responses_api,
+)
 from backend.util.prompt import compress_context, estimate_token_count
 from backend.util.text import TextFormatter
@@ -659,38 +667,72 @@ async def llm_call(
     max_tokens = max(min(available_tokens, model_max_output, user_max), 1)
 
     if provider == "openai":
-        tools_param = tools if tools else openai.NOT_GIVEN
         oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
-        response_format = None
 
-        parallel_tool_calls = get_parallel_tool_calls_param(
-            llm_model, parallel_tool_calls
-        )
-
-        if force_json_output:
-            response_format = {"type": "json_object"}
-
-        response = await oai_client.chat.completions.create(
-            model=llm_model.value,
-            messages=prompt,  # type: ignore
-            response_format=response_format,  # type: ignore
-            max_completion_tokens=max_tokens,
-            tools=tools_param,  # type: ignore
-            parallel_tool_calls=parallel_tool_calls,
-        )
-
-        tool_calls = extract_openai_tool_calls(response)
-        reasoning = extract_openai_reasoning(response)
-
-        return LLMResponse(
-            raw_response=response.choices[0].message,
-            prompt=prompt,
-            response=response.choices[0].message.content or "",
-            tool_calls=tool_calls,
-            prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
-            completion_tokens=response.usage.completion_tokens if response.usage else 0,
-            reasoning=reasoning,
-        )
+        # Check if this model requires the Responses API (reasoning models: o1, o3, etc.)
+        if requires_responses_api(llm_model.value):
+            # Use responses.create for reasoning models
+            tools_converted = (
+                convert_tools_to_responses_format(tools) if tools else None
+            )
+
+            response = await oai_client.responses.create(
+                model=llm_model.value,
+                input=prompt,  # type: ignore
+                tools=tools_converted,  # type: ignore
+                max_output_tokens=max_tokens,
+                store=False,  # Don't persist conversations
+            )
+
+            tool_calls = extract_responses_tool_calls(response)
+            reasoning = extract_responses_reasoning(response)
+            content = extract_responses_content(response)
+            prompt_tokens, completion_tokens = extract_usage(response, True)
+
+            return LLMResponse(
+                raw_response=response,
+                prompt=prompt,
+                response=content,
+                tool_calls=tool_calls,
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                reasoning=reasoning,
+            )
+        else:
+            # Use chat.completions.create for standard models
+            tools_param = tools if tools else openai.NOT_GIVEN
+            response_format = None
+
+            parallel_tool_calls = get_parallel_tool_calls_param(
+                llm_model, parallel_tool_calls
+            )
+
+            if force_json_output:
+                response_format = {"type": "json_object"}
+
+            response = await oai_client.chat.completions.create(
+                model=llm_model.value,
+                messages=prompt,  # type: ignore
+                response_format=response_format,  # type: ignore
+                max_completion_tokens=max_tokens,
+                tools=tools_param,  # type: ignore
+                parallel_tool_calls=parallel_tool_calls,
+            )
+
+            tool_calls = extract_openai_tool_calls(response)
+            reasoning = extract_openai_reasoning(response)
+
+            return LLMResponse(
+                raw_response=response.choices[0].message,
+                prompt=prompt,
+                response=response.choices[0].message.content or "",
+                tool_calls=tool_calls,
+                prompt_tokens=response.usage.prompt_tokens if response.usage else 0,
+                completion_tokens=(
+                    response.usage.completion_tokens if response.usage else 0
+                ),
+                reasoning=reasoning,
+            )
     elif provider == "anthropic":
 
         an_tools = convert_openai_tool_fmt_to_anthropic(tools)
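For orientation, a minimal sketch of the routing this hunk introduces: requires_responses_api() decides whether a model goes through responses.create or chat.completions.create, and the new helpers normalize both result shapes. The ask() wrapper, the 1024-token cap, and the example model names below are illustrative assumptions, not part of the commit; the sketch assumes the backend package is importable and OPENAI_API_KEY is set.

# Sketch only: the same two-branch dispatch as the hunk above, outside of llm_call.
import asyncio
import os

import openai

from backend.util.openai_responses import (
    extract_responses_content,
    extract_usage,
    requires_responses_api,
)


async def ask(model: str, prompt: list[dict]) -> tuple[str, int, int]:
    """Route a call to the right OpenAI endpoint and return
    (text, prompt_tokens, completion_tokens) in one shape."""
    client = openai.AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"])

    if requires_responses_api(model):
        # Reasoning models (o1/o3 families) go through the Responses API.
        response = await client.responses.create(
            model=model,
            input=prompt,  # type: ignore
            max_output_tokens=1024,
            store=False,
        )
        text = extract_responses_content(response)
        prompt_tokens, completion_tokens = extract_usage(response, True)
    else:
        # Everything else stays on Chat Completions.
        response = await client.chat.completions.create(
            model=model,
            messages=prompt,  # type: ignore
            max_completion_tokens=1024,
        )
        text = response.choices[0].message.content or ""
        prompt_tokens, completion_tokens = extract_usage(response, False)

    return text, prompt_tokens, completion_tokens


if __name__ == "__main__":
    messages = [{"role": "user", "content": "Say hello in one word."}]
    print(asyncio.run(ask("o3-mini", messages)))      # Responses API branch
    print(asyncio.run(ask("gpt-4o-mini", messages)))  # Chat Completions branch

In the hunk itself the two branches converge on a single LLMResponse, so callers of llm_call never see which endpoint served the request.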
autogpt_platform/backend/backend/util/openai_responses.py (new file, 185 lines)
@@ -0,0 +1,185 @@
"""Helpers for OpenAI Responses API migration.

This module provides utilities for conditionally using OpenAI's Responses API
instead of Chat Completions for reasoning models (o1, o3, etc.) that require it.
"""

from typing import Any

# Exact model identifiers that require the Responses API.
# Use exact matching to avoid false positives on future models.
# NOTE: Update this set when OpenAI releases new reasoning models.
REASONING_MODELS = frozenset(
    {
        # O1 family
        "o1",
        "o1-mini",
        "o1-preview",
        "o1-2024-12-17",
        # O3 family
        "o3",
        "o3-mini",
        "o3-2025-04-16",
        "o3-mini-2025-01-31",
    }
)


def requires_responses_api(model: str) -> bool:
    """Check if model requires the Responses API (exact match).

    Args:
        model: The model identifier string (e.g., "o3-mini", "gpt-4o")

    Returns:
        True if the model requires responses.create, False otherwise
    """
    return model in REASONING_MODELS


def convert_tools_to_responses_format(tools: list[dict] | None) -> list[dict]:
    """Convert Chat Completions tool format to Responses API format.

    The Responses API uses internally-tagged polymorphism (flatter structure)
    and functions are strict by default.

    Chat Completions format:
        {"type": "function", "function": {"name": "...", "parameters": {...}}}

    Responses API format:
        {"type": "function", "name": "...", "parameters": {...}}

    Args:
        tools: List of tools in Chat Completions format

    Returns:
        List of tools in Responses API format
    """
    if not tools:
        return []

    converted = []
    for tool in tools:
        if tool.get("type") == "function":
            func = tool.get("function", {})
            converted.append(
                {
                    "type": "function",
                    "name": func.get("name"),
                    "description": func.get("description"),
                    "parameters": func.get("parameters"),
                    # Note: strict=True is default in Responses API
                }
            )
        else:
            # Pass through non-function tools as-is
            converted.append(tool)
    return converted


def extract_responses_tool_calls(response: Any) -> list[dict] | None:
    """Extract tool calls from Responses API response.

    The Responses API returns tool calls as separate items in the output array
    with type="function_call".

    Args:
        response: The Responses API response object

    Returns:
        List of tool calls in a normalized format, or None if no tool calls
    """
    tool_calls = []
    for item in response.output:
        if getattr(item, "type", None) == "function_call":
            tool_calls.append(
                {
                    "id": item.call_id,
                    "type": "function",
                    "function": {
                        "name": item.name,
                        "arguments": item.arguments,
                    },
                }
            )
    return tool_calls if tool_calls else None


def extract_usage(response: Any, is_responses_api: bool) -> tuple[int, int]:
    """Extract token usage from either API response.

    The Responses API uses different field names for token counts:
    - Chat Completions: prompt_tokens, completion_tokens
    - Responses API: input_tokens, output_tokens

    Args:
        response: The API response object
        is_responses_api: True if response is from Responses API

    Returns:
        Tuple of (prompt_tokens, completion_tokens)
    """
    if not response.usage:
        return 0, 0

    if is_responses_api:
        # Responses API uses different field names
        return (
            getattr(response.usage, "input_tokens", 0),
            getattr(response.usage, "output_tokens", 0),
        )
    else:
        # Chat Completions API
        return (
            getattr(response.usage, "prompt_tokens", 0),
            getattr(response.usage, "completion_tokens", 0),
        )


def extract_responses_content(response: Any) -> str:
    """Extract text content from Responses API response.

    Args:
        response: The Responses API response object

    Returns:
        The text content from the response, or empty string if none
    """
    # The SDK provides a helper property
    if hasattr(response, "output_text"):
        return response.output_text or ""

    # Fallback: manually extract from output items
    for item in response.output:
        if getattr(item, "type", None) == "message":
            for content in getattr(item, "content", []):
                if getattr(content, "type", None) == "output_text":
                    return getattr(content, "text", "")
    return ""


def extract_responses_reasoning(response: Any) -> str | None:
    """Extract reasoning content from Responses API response.

    Reasoning models return their reasoning process in the response,
    which can be useful for debugging or display.

    Args:
        response: The Responses API response object

    Returns:
        The reasoning text, or None if not present
    """
    for item in response.output:
        if getattr(item, "type", None) == "reasoning":
            # Reasoning items may have summary or content
            summary = getattr(item, "summary", [])
            if summary:
                # Join summary items if present
                texts = []
                for s in summary:
                    if hasattr(s, "text"):
                        texts.append(s.text)
                if texts:
                    return "\n".join(texts)
    return None
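The extractors above rely only on attribute access (response.output, response.usage), so their behavior can be checked without a network call. A small sketch, using types.SimpleNamespace to stand in for the SDK's response objects; the field values are made up for illustration and are not from the commit.

# Sketch: exercising the extractors against hand-built fake response objects.
from types import SimpleNamespace

from backend.util.openai_responses import (
    extract_responses_content,
    extract_responses_tool_calls,
    extract_usage,
)

# Fake Responses API output: one function_call item and one message item.
fake_response = SimpleNamespace(
    output=[
        SimpleNamespace(
            type="function_call",
            call_id="call_123",
            name="get_weather",
            arguments='{"location": "Berlin"}',
        ),
        SimpleNamespace(
            type="message",
            content=[SimpleNamespace(type="output_text", text="It is sunny.")],
        ),
    ],
    usage=SimpleNamespace(input_tokens=42, output_tokens=7),
)

print(extract_responses_tool_calls(fake_response))
# [{'id': 'call_123', 'type': 'function',
#   'function': {'name': 'get_weather', 'arguments': '{"location": "Berlin"}'}}]

print(extract_responses_content(fake_response))
# 'It is sunny.'  (the fake has no output_text attribute, so the fallback scan is used)

print(extract_usage(fake_response, True))
# (42, 7)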
autogpt_platform/backend/backend/util/openai_responses_test.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""Tests for OpenAI Responses API helpers."""

import pytest

from backend.util.openai_responses import (
    REASONING_MODELS,
    convert_tools_to_responses_format,
    requires_responses_api,
)


class TestRequiresResponsesApi:
    """Tests for the requires_responses_api function."""

    def test_o1_models_require_responses_api(self):
        """O1 family models should require the Responses API."""
        assert requires_responses_api("o1") is True
        assert requires_responses_api("o1-mini") is True
        assert requires_responses_api("o1-preview") is True
        assert requires_responses_api("o1-2024-12-17") is True

    def test_o3_models_require_responses_api(self):
        """O3 family models should require the Responses API."""
        assert requires_responses_api("o3") is True
        assert requires_responses_api("o3-mini") is True
        assert requires_responses_api("o3-2025-04-16") is True
        assert requires_responses_api("o3-mini-2025-01-31") is True

    def test_gpt_models_do_not_require_responses_api(self):
        """GPT models should NOT require the Responses API."""
        assert requires_responses_api("gpt-4o") is False
        assert requires_responses_api("gpt-4o-mini") is False
        assert requires_responses_api("gpt-4-turbo") is False
        assert requires_responses_api("gpt-3.5-turbo") is False
        assert requires_responses_api("gpt-5") is False
        assert requires_responses_api("gpt-5-mini") is False

    def test_other_models_do_not_require_responses_api(self):
        """Other provider models should NOT require the Responses API."""
        assert requires_responses_api("claude-3-opus") is False
        assert requires_responses_api("llama-3.3-70b") is False
        assert requires_responses_api("gemini-pro") is False

    def test_empty_string_does_not_require_responses_api(self):
        """Empty string should not require the Responses API."""
        assert requires_responses_api("") is False

    def test_exact_matching_no_false_positives(self):
        """Should not match models that just start with 'o1' or 'o3'."""
        # These are hypothetical models that start with o1/o3 but aren't
        # actually reasoning models
        assert requires_responses_api("o1-turbo-hypothetical") is False
        assert requires_responses_api("o3-fast-hypothetical") is False
        assert requires_responses_api("o100") is False


class TestConvertToolsToResponsesFormat:
    """Tests for the convert_tools_to_responses_format function."""

    def test_empty_tools_returns_empty_list(self):
        """Empty or None tools should return empty list."""
        assert convert_tools_to_responses_format(None) == []
        assert convert_tools_to_responses_format([]) == []

    def test_converts_function_tool_format(self):
        """Should convert Chat Completions function format to Responses format."""
        chat_completions_tools = [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the weather in a location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string"},
                        },
                        "required": ["location"],
                    },
                },
            }
        ]

        result = convert_tools_to_responses_format(chat_completions_tools)

        assert len(result) == 1
        assert result[0]["type"] == "function"
        assert result[0]["name"] == "get_weather"
        assert result[0]["description"] == "Get the weather in a location"
        assert result[0]["parameters"] == {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
            },
            "required": ["location"],
        }
        # Should not have nested "function" key
        assert "function" not in result[0]

    def test_handles_multiple_tools(self):
        """Should handle multiple tools."""
        chat_completions_tools = [
            {
                "type": "function",
                "function": {
                    "name": "tool_1",
                    "description": "First tool",
                    "parameters": {"type": "object", "properties": {}},
                },
            },
            {
                "type": "function",
                "function": {
                    "name": "tool_2",
                    "description": "Second tool",
                    "parameters": {"type": "object", "properties": {}},
                },
            },
        ]

        result = convert_tools_to_responses_format(chat_completions_tools)

        assert len(result) == 2
        assert result[0]["name"] == "tool_1"
        assert result[1]["name"] == "tool_2"

    def test_passes_through_non_function_tools(self):
        """Non-function tools should be passed through as-is."""
        tools = [{"type": "web_search", "config": {"enabled": True}}]

        result = convert_tools_to_responses_format(tools)

        assert result == tools


class TestReasoningModelsSet:
    """Tests for the REASONING_MODELS constant."""

    def test_reasoning_models_is_frozenset(self):
        """REASONING_MODELS should be a frozenset (immutable)."""
        assert isinstance(REASONING_MODELS, frozenset)

    def test_contains_expected_models(self):
        """Should contain all expected reasoning models."""
        expected = {
            "o1",
            "o1-mini",
            "o1-preview",
            "o1-2024-12-17",
            "o3",
            "o3-mini",
            "o3-2025-04-16",
            "o3-mini-2025-01-31",
        }
        assert expected.issubset(REASONING_MODELS)
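The committed tests cover model matching and tool conversion; the usage and reasoning extractors could be exercised in the same style with lightweight fakes. A possible companion sketch, not part of the commit, assuming the same imports and test layout:

# Sketch: possible companion tests for the usage and reasoning extractors.
from types import SimpleNamespace

from backend.util.openai_responses import (
    extract_responses_reasoning,
    extract_usage,
)


class TestExtractUsage:
    def test_responses_api_field_names(self):
        response = SimpleNamespace(usage=SimpleNamespace(input_tokens=10, output_tokens=3))
        assert extract_usage(response, True) == (10, 3)

    def test_chat_completions_field_names(self):
        response = SimpleNamespace(usage=SimpleNamespace(prompt_tokens=10, completion_tokens=3))
        assert extract_usage(response, False) == (10, 3)

    def test_missing_usage_returns_zeroes(self):
        response = SimpleNamespace(usage=None)
        assert extract_usage(response, True) == (0, 0)


class TestExtractResponsesReasoning:
    def test_joins_summary_texts(self):
        item = SimpleNamespace(
            type="reasoning",
            summary=[SimpleNamespace(text="step 1"), SimpleNamespace(text="step 2")],
        )
        response = SimpleNamespace(output=[item])
        assert extract_responses_reasoning(response) == "step 1\nstep 2"

    def test_no_reasoning_item_returns_none(self):
        response = SimpleNamespace(output=[SimpleNamespace(type="message", content=[])])
        assert extract_responses_reasoning(response) is None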
@@ -10,7 +10,7 @@ import {
 } from "@/components/ui/tooltip";
 import { cn } from "@/lib/utils";
 import { cjk } from "@streamdown/cjk";
-import { code } from "@/lib/streamdown-code-singleton";
+import { code } from "@streamdown/code";
 import { math } from "@streamdown/math";
 import { mermaid } from "@streamdown/mermaid";
 import type { UIMessage } from "ai";
(deleted file)
@@ -1,237 +0,0 @@
/**
 * Custom Streamdown code plugin with proper shiki singleton.
 *
 * Fixes SENTRY-1051: @streamdown/code creates a new shiki highlighter per language,
 * causing "10 instances created" warnings and memory bloat.
 *
 * This plugin creates ONE highlighter and loads languages dynamically.
 */

import {
  createHighlighter,
  bundledLanguages,
  type Highlighter,
  type BundledLanguage,
  type BundledTheme,
} from "shiki";

// Types matching streamdown's expected interface
interface HighlightToken {
  content: string;
  color?: string;
  bgColor?: string;
  htmlStyle?: Record<string, string>;
  htmlAttrs?: Record<string, string>;
  offset?: number;
}

interface HighlightResult {
  tokens: HighlightToken[][];
  fg?: string;
  bg?: string;
}

interface HighlightOptions {
  code: string;
  language: BundledLanguage;
  themes: [string, string];
}

interface CodeHighlighterPlugin {
  name: "shiki";
  type: "code-highlighter";
  highlight: (
    options: HighlightOptions,
    callback?: (result: HighlightResult) => void
  ) => HighlightResult | null;
  supportsLanguage: (language: BundledLanguage) => boolean;
  getSupportedLanguages: () => BundledLanguage[];
  getThemes: () => [BundledTheme, BundledTheme];
}

// Singleton state
let highlighterPromise: Promise<Highlighter> | null = null;
let highlighterInstance: Highlighter | null = null;
const loadedLanguages = new Set<string>();
const pendingLanguages = new Map<string, Promise<void>>();

// Result cache (same as @streamdown/code)
const resultCache = new Map<string, HighlightResult>();
const pendingCallbacks = new Map<string, Set<(result: HighlightResult) => void>>();

// All supported languages
const supportedLanguages = new Set(Object.keys(bundledLanguages));

// Cache key for results
function getCacheKey(code: string, language: string, themes: [string, string]): string {
  const prefix = code.slice(0, 100);
  const suffix = code.length > 100 ? code.slice(-100) : "";
  return `${language}:${themes[0]}:${themes[1]}:${code.length}:${prefix}:${suffix}`;
}

// Get or create the singleton highlighter
async function getHighlighter(themes: [string, string]): Promise<Highlighter> {
  if (highlighterInstance) {
    return highlighterInstance;
  }

  if (!highlighterPromise) {
    highlighterPromise = createHighlighter({
      themes: themes as [BundledTheme, BundledTheme],
      // Start with common languages pre-loaded for faster first render
      langs: ["javascript", "typescript", "python", "json", "html", "css", "bash", "markdown"],
    }).then((h: Highlighter) => {
      highlighterInstance = h;
      ["javascript", "typescript", "python", "json", "html", "css", "bash", "markdown"].forEach(
        (l) => loadedLanguages.add(l)
      );
      return h;
    });
  }

  return highlighterPromise;
}

// Load a language dynamically
async function ensureLanguageLoaded(
  highlighter: Highlighter,
  language: string
): Promise<void> {
  if (loadedLanguages.has(language)) {
    return;
  }

  if (pendingLanguages.has(language)) {
    return pendingLanguages.get(language);
  }

  const loadPromise = highlighter
    .loadLanguage(language as BundledLanguage)
    .then(() => {
      loadedLanguages.add(language);
      pendingLanguages.delete(language);
    })
    .catch((err: Error) => {
      console.warn(`[streamdown-code-singleton] Failed to load language: ${language}`, err);
      pendingLanguages.delete(language);
    });

  pendingLanguages.set(language, loadPromise);
  return loadPromise;
}

// Shiki token types
interface ShikiToken {
  content: string;
  color?: string;
  htmlStyle?: Record<string, string>;
}

// Convert shiki tokens to streamdown format
function convertTokens(
  shikiResult: ReturnType<Highlighter["codeToTokens"]>
): HighlightResult {
  return {
    tokens: shikiResult.tokens.map((line: ShikiToken[]) =>
      line.map((token: ShikiToken) => ({
        content: token.content,
        color: token.color,
        htmlStyle: token.htmlStyle,
      }))
    ),
    fg: shikiResult.fg,
    bg: shikiResult.bg,
  };
}

export interface CodePluginOptions {
  themes?: [BundledTheme, BundledTheme];
}

export function createCodePlugin(
  options: CodePluginOptions = {}
): CodeHighlighterPlugin {
  const themes = options.themes ?? ["github-light", "github-dark"];

  return {
    name: "shiki",
    type: "code-highlighter",

    supportsLanguage(language: BundledLanguage): boolean {
      return supportedLanguages.has(language);
    },

    getSupportedLanguages(): BundledLanguage[] {
      return Array.from(supportedLanguages) as BundledLanguage[];
    },

    getThemes(): [BundledTheme, BundledTheme] {
      return themes as [BundledTheme, BundledTheme];
    },

    highlight(
      { code, language, themes: highlightThemes }: HighlightOptions,
      callback?: (result: HighlightResult) => void
    ): HighlightResult | null {
      const cacheKey = getCacheKey(code, language, highlightThemes);

      // Return cached result if available
      if (resultCache.has(cacheKey)) {
        return resultCache.get(cacheKey)!;
      }

      // Register callback for async result
      if (callback) {
        if (!pendingCallbacks.has(cacheKey)) {
          pendingCallbacks.set(cacheKey, new Set());
        }
        pendingCallbacks.get(cacheKey)!.add(callback);
      }

      // Start async highlighting
      getHighlighter(highlightThemes)
        .then(async (highlighter) => {
          // Ensure language is loaded
          const lang = supportedLanguages.has(language) ? language : "text";
          if (lang !== "text") {
            await ensureLanguageLoaded(highlighter, lang);
          }

          // Highlight the code
          const effectiveLang = highlighter.getLoadedLanguages().includes(lang)
            ? lang
            : "text";

          const shikiResult = highlighter.codeToTokens(code, {
            lang: effectiveLang,
            themes: {
              light: highlightThemes[0] as BundledTheme,
              dark: highlightThemes[1] as BundledTheme,
            },
          });

          const result = convertTokens(shikiResult);
          resultCache.set(cacheKey, result);

          // Notify all pending callbacks
          const callbacks = pendingCallbacks.get(cacheKey);
          if (callbacks) {
            for (const cb of callbacks) {
              cb(result);
            }
            pendingCallbacks.delete(cacheKey);
          }
        })
        .catch((err) => {
          console.error("[streamdown-code-singleton] Failed to highlight code:", err);
          pendingCallbacks.delete(cacheKey);
        });

      // Return null while async loading
      return null;
    },
  };
}

// Pre-configured plugin with default settings (drop-in replacement for @streamdown/code)
export const code = createCodePlugin();