Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 00:28:31 -05:00)

Compare commits: 61 commits on branch ntindle/sa
| SHA1 |
|---|
| ab3a62995f |
| f2e9a8463d |
| 107148749b |
| 17f1d33ed3 |
| b3a0fc538a |
| e818bbf859 |
| 069ec89691 |
| f6ab15db47 |
| 84d490bcb1 |
| 80161decb9 |
| 0bf8edcd96 |
| 9368956d5d |
| 2c3bde0c53 |
| 104b56628e |
| b1347a92de |
| 22ce8e0047 |
| 5a7193cfb7 |
| 15ac526eee |
| c1f301ab8b |
| 5f83e354b9 |
| 70ebf4d58b |
| 6d0d264d99 |
| f32244a112 |
| 8e24b546a3 |
| d4838cdc45 |
| acaca35498 |
| 9ee0825f21 |
| 5fde0f2c67 |
| e3407fdfb4 |
| b98e62cdef |
| 4d82f78f04 |
| c5d2586f6c |
| 589c8d94ec |
| 136d258a46 |
| 92bcc39f4d |
| 5909697215 |
| bf34801a74 |
| 154eccb9af |
| 14f8a92c20 |
| 2c07c64ccf |
| ef21d359a6 |
| f4bd998fa2 |
| 4ebae90f62 |
| 09d3768948 |
| 8c6adaeaa1 |
| dabd2e1610 |
| b228c4445e |
| 05c9931c11 |
| 9198a86c0e |
| c8fedf3dad |
| 0c7e1838cd |
| 979d80cd17 |
| 4f7ffd13e4 |
| b944e0f6da |
| 51aaaf6ddc |
| 3c662af1ba |
| 17370116f6 |
| d15049e9a7 |
| da4afd4530 |
| 7617aa6d1f |
| b190e1f2aa |
rnd/autogpt_builder/.gitignore (vendored, 3 lines changed)

@@ -34,3 +34,6 @@ yarn-error.log*
 # typescript
 *.tsbuildinfo
 next-env.d.ts
+
+# Sentry Config File
+.env.sentry-build-plugin
rnd/autogpt_builder/next.config.mjs

@@ -1,3 +1,4 @@
+import { withSentryConfig } from "@sentry/nextjs";
 import dotenv from "dotenv";

 // Load environment variables
@@ -28,4 +29,56 @@ const nextConfig = {
   },
 };

-export default nextConfig;
+export default withSentryConfig(nextConfig, {
+  // For all available options, see:
+  // https://github.com/getsentry/sentry-webpack-plugin#options
+
+  org: "significant-gravitas",
+  project: "builder",
+
+  // Only print logs for uploading source maps in CI
+  silent: !process.env.CI,
+
+  // For all available options, see:
+  // https://docs.sentry.io/platforms/javascript/guides/nextjs/manual-setup/
+
+  // Upload a larger set of source maps for prettier stack traces (increases build time)
+  widenClientFileUpload: true,
+
+  // Automatically annotate React components to show their full name in breadcrumbs and session replay
+  reactComponentAnnotation: {
+    enabled: true,
+  },
+
+  // Route browser requests to Sentry through a Next.js rewrite to circumvent ad-blockers.
+  // This can increase your server load as well as your hosting bill.
+  // Note: Check that the configured route will not match with your Next.js middleware, otherwise reporting of client-
+  // side errors will fail.
+  tunnelRoute: "/monitoring",
+
+  // Hides source maps from generated client bundles
+  hideSourceMaps: true,
+
+  // Automatically tree-shake Sentry logger statements to reduce bundle size
+  disableLogger: true,
+
+  // Enables automatic instrumentation of Vercel Cron Monitors. (Does not yet work with App Router route handlers.)
+  // See the following for more information:
+  // https://docs.sentry.io/product/crons/
+  // https://vercel.com/docs/cron-jobs
+  automaticVercelMonitors: true,
+
+  async headers() {
+    return [
+      {
+        source: "/:path*",
+        headers: [
+          {
+            key: "Document-Policy",
+            value: "js-profiling",
+          },
+        ],
+      },
+    ];
+  },
+});
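One thing worth flagging in the hunk above: `async headers()` is added to the second argument of `withSentryConfig`, which is the Sentry build-options object, not the Next.js config. To the best of my knowledge Next.js only reads `headers()` from the config object itself, so the header would be silently ignored there. A hedged sketch of the likely intended placement (assuming the goal is to serve the `Document-Policy: js-profiling` header that browser profiling needs):

```ts
// Sketch, not part of the PR: headers() belongs on the Next.js config object,
// not in the Sentry plugin options (assumption based on the Next.js config contract).
const nextConfig = {
  async headers() {
    return [
      {
        source: "/:path*",
        headers: [{ key: "Document-Policy", value: "js-profiling" }],
      },
    ];
  },
};

export default withSentryConfig(nextConfig, {
  org: "significant-gravitas",
  project: "builder",
});
```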
rnd/autogpt_builder/package.json

@@ -27,6 +27,7 @@
     "@radix-ui/react-switch": "^1.1.0",
     "@radix-ui/react-toast": "^1.2.1",
     "@radix-ui/react-tooltip": "^1.1.2",
+    "@sentry/nextjs": "^8",
     "@supabase/ssr": "^0.4.0",
     "@supabase/supabase-js": "^2.45.0",
     "@tanstack/react-table": "^8.20.5",
rnd/autogpt_builder/sentry.client.config.ts (new file, 57 lines)

@@ -0,0 +1,57 @@
// This file configures the initialization of Sentry on the client.
// The config you add here will be used whenever a user loads a page in their browser.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/

import * as Sentry from "@sentry/nextjs";

Sentry.init({
  dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",

  // Add optional integrations for additional features
  integrations: [
    Sentry.replayIntegration(),
    Sentry.httpClientIntegration(),
    Sentry.replayCanvasIntegration(),
    Sentry.reportingObserverIntegration(),
    Sentry.browserProfilingIntegration(),
    // Sentry.feedbackIntegration({
    //   // Additional SDK configuration goes in here, for example:
    //   colorScheme: "system",
    // }),
  ],

  // Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
  tracesSampleRate: 1,

  // Set `tracePropagationTargets` to control for which URLs trace propagation should be enabled
  tracePropagationTargets: [
    "localhost",
    /^https:\/\/dev\-builder\.agpt\.co\/api/,
  ],

  beforeSend(event, hint) {
    // Check if it is an exception, and if so, show the report dialog
    if (event.exception && event.event_id) {
      Sentry.showReportDialog({ eventId: event.event_id });
    }
    return event;
  },

  // Define how likely Replay events are sampled.
  // This sets the sample rate to be 10%. You may want this to be 100% while
  // in development and sample at a lower rate in production
  replaysSessionSampleRate: 0.1,

  // Define how likely Replay events are sampled when an error occurs.
  replaysOnErrorSampleRate: 1.0,

  // Setting this option to true will print useful information to the console while you're setting up Sentry.
  debug: false,

  // Set profilesSampleRate to 1.0 to profile every transaction.
  // Since profilesSampleRate is relative to tracesSampleRate,
  // the final profiling rate can be computed as tracesSampleRate * profilesSampleRate
  // For example, a tracesSampleRate of 0.5 and profilesSampleRate of 0.5 would
  // result in 25% of transactions being profiled (0.5*0.5=0.25)
  profilesSampleRate: 1.0,
});
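Since `tracesSampleRate: 1` traces every transaction, production deployments usually dial this down. A hedged sketch of the `tracesSampler` alternative the comment alludes to; the callback form is part of the Sentry JS SDK, but the exact shape of `samplingContext` varies between SDK majors, and the rates and route name here are purely illustrative:

```ts
// Sketch only: sample health checks lightly, everything else at 10%.
Sentry.init({
  dsn: "...",
  tracesSampler: (samplingContext) => {
    // `samplingContext.name` is an assumption about the v8 context shape.
    if (samplingContext.name?.includes("/health")) return 0.01;
    return 0.1;
  },
});
```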
rnd/autogpt_builder/sentry.edge.config.ts (new file, 16 lines)

@@ -0,0 +1,16 @@
// This file configures the initialization of Sentry for edge features (middleware, edge routes, and so on).
// The config you add here will be used whenever one of the edge features is loaded.
// Note that this config is unrelated to the Vercel Edge Runtime and is also required when running locally.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/

import * as Sentry from "@sentry/nextjs";

Sentry.init({
  dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",

  // Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
  tracesSampleRate: 1,

  // Setting this option to true will print useful information to the console while you're setting up Sentry.
  debug: false,
});
rnd/autogpt_builder/sentry.server.config.ts (new file, 23 lines)

@@ -0,0 +1,23 @@
// This file configures the initialization of Sentry on the server.
// The config you add here will be used whenever the server handles a request.
// https://docs.sentry.io/platforms/javascript/guides/nextjs/

import * as Sentry from "@sentry/nextjs";
// import { NodeProfilingIntegration } from "@sentry/profiling-node";

Sentry.init({
  dsn: "https://fe4e4aa4a283391808a5da396da20159@o4505260022104064.ingest.us.sentry.io/4507946746380288",

  // Define how likely traces are sampled. Adjust this value in production, or use tracesSampler for greater control.
  tracesSampleRate: 1,

  // Setting this option to true will print useful information to the console while you're setting up Sentry.
  debug: false,

  // Integrations
  integrations: [
    Sentry.anrIntegration(),
    // NodeProfilingIntegration,
    // Sentry.fsIntegration(),
  ],
});
rnd/autogpt_builder/src/app/global-error.tsx (new file, 27 lines)

@@ -0,0 +1,27 @@
"use client";

import * as Sentry from "@sentry/nextjs";
import NextError from "next/error";
import { useEffect } from "react";

export default function GlobalError({
  error,
}: {
  error: Error & { digest?: string };
}) {
  useEffect(() => {
    Sentry.captureException(error);
  }, [error]);

  return (
    <html>
      <body>
        {/* `NextError` is the default Next.js error page component. Its type
        definition requires a `statusCode` prop. However, since the App Router
        does not expose status codes for errors, we simply pass 0 to render a
        generic error message. */}
        <NextError statusCode={0} />
      </body>
    </html>
  );
}
rnd/autogpt_builder/src/app/login/actions.ts

@@ -3,6 +3,7 @@ import { revalidatePath } from "next/cache";
 import { redirect } from "next/navigation";
 import { createServerClient } from "@/lib/supabase/server";
 import { z } from "zod";
+import * as Sentry from "@sentry/nextjs";

 const loginFormSchema = z.object({
   email: z.string().email().min(2).max(64),
@@ -10,45 +11,53 @@ const loginFormSchema = z.object({
 });

 export async function login(values: z.infer<typeof loginFormSchema>) {
-  const supabase = createServerClient();
+  return await Sentry.withServerActionInstrumentation("login", {}, async () => {
+    const supabase = createServerClient();

-  if (!supabase) {
-    redirect("/error");
-  }
+    if (!supabase) {
+      redirect("/error");
+    }

-  // We are sure that the values are of the correct type because zod validates the form
-  const { data, error } = await supabase.auth.signInWithPassword(values);
+    // We are sure that the values are of the correct type because zod validates the form
+    const { data, error } = await supabase.auth.signInWithPassword(values);

-  if (error) {
-    return error.message;
-  }
+    if (error) {
+      return error.message;
+    }

-  if (data.session) {
-    await supabase.auth.setSession(data.session);
-  }
+    if (data.session) {
+      await supabase.auth.setSession(data.session);
+    }

-  revalidatePath("/", "layout");
-  redirect("/profile");
+    revalidatePath("/", "layout");
+    redirect("/profile");
+  });
 }

 export async function signup(values: z.infer<typeof loginFormSchema>) {
-  const supabase = createServerClient();
+  return await Sentry.withServerActionInstrumentation(
+    "signup",
+    {},
+    async () => {
+      const supabase = createServerClient();

-  if (!supabase) {
-    redirect("/error");
-  }
+      if (!supabase) {
+        redirect("/error");
+      }

-  // We are sure that the values are of the correct type because zod validates the form
-  const { data, error } = await supabase.auth.signUp(values);
+      // We are sure that the values are of the correct type because zod validates the form
+      const { data, error } = await supabase.auth.signUp(values);

-  if (error) {
-    return error.message;
-  }
+      if (error) {
+        return error.message;
+      }

-  if (data.session) {
-    await supabase.auth.setSession(data.session);
-  }
+      if (data.session) {
+        await supabase.auth.setSession(data.session);
+      }

-  revalidatePath("/", "layout");
-  redirect("/profile");
+      revalidatePath("/", "layout");
+      redirect("/profile");
+    },
+  );
 }
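Every action in this PR repeats the same `withServerActionInstrumentation(name, {}, async () => { ... })` scaffolding. A hedged sketch of a generic helper that would factor the pattern out; the `instrumented` name and shape are my own, not part of the PR:

```ts
import * as Sentry from "@sentry/nextjs";

// Hypothetical helper (not in the PR): wraps any server action so its body
// runs inside Sentry's server-action instrumentation under a stable name.
export function instrumented<Args extends unknown[], R>(
  name: string,
  action: (...args: Args) => Promise<R>,
): (...args: Args) => Promise<R> {
  return (...args: Args) =>
    Sentry.withServerActionInstrumentation(name, {}, () => action(...args));
}

// Usage sketch: export const login = instrumented("login", async (values) => { ... });
```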
rnd/autogpt_builder/src/components/Flow.tsx

@@ -27,7 +27,7 @@ import "@xyflow/react/dist/style.css";
 import { CustomNode } from "./CustomNode";
 import "./flow.css";
 import { Link } from "@/lib/autogpt-server-api";
-import { getTypeColor } from "@/lib/utils";
+import { getTypeColor, filterBlocksByType } from "@/lib/utils";
 import { history } from "./history";
 import { CustomEdge } from "./CustomEdge";
 import ConnectionLine from "./ConnectionLine";
@@ -36,14 +36,19 @@ import { SaveControl } from "@/components/edit/control/SaveControl";
 import { BlocksControl } from "@/components/edit/control/BlocksControl";
 import {
   IconPlay,
-  IconUndo2,
   IconRedo2,
   IconSquare,
+  IconUndo2,
+  IconOutput,
 } from "@/components/ui/icons";
 import { startTutorial } from "./tutorial";
 import useAgentGraph from "@/hooks/useAgentGraph";
 import { v4 as uuidv4 } from "uuid";
 import { useRouter, usePathname, useSearchParams } from "next/navigation";
+import { LogOut } from "lucide-react";
+import RunnerUIWrapper, {
+  RunnerUIWrapperRef,
+} from "@/components/RunnerUIWrapper";

 // This is for the history, this is the minimum distance a block must move before it is logged
 // It helps to prevent spamming the history with small movements, especially when pressing on an input in a block
@@ -101,6 +106,8 @@ const FlowEditor: React.FC<{
   // State to control if blocks menu should be pinned open
   const [pinBlocksPopover, setPinBlocksPopover] = useState(false);

+  const runnerUIRef = useRef<RunnerUIWrapperRef>(null);
+
   useEffect(() => {
     const params = new URLSearchParams(window.location.search);

@@ -550,9 +557,21 @@ const FlowEditor: React.FC<{
       onClick: handleRedo,
     },
     {
-      label: !isRunning ? "Run" : "Stop",
+      label: !savedAgent
+        ? "Please save the agent to run"
+        : !isRunning
+          ? "Run"
+          : "Stop",
       icon: !isRunning ? <IconPlay /> : <IconSquare />,
-      onClick: !isRunning ? requestSaveAndRun : requestStopRun,
+      onClick: !isRunning
+        ? () => runnerUIRef.current?.runOrOpenInput()
+        : requestStopRun,
+      disabled: !savedAgent,
     },
+    {
+      label: "Runner Output",
+      icon: <LogOut size={18} strokeWidth={1.8} />,
+      onClick: () => runnerUIRef.current?.openRunnerOutput(),
+    },
   ];

@@ -588,12 +607,21 @@ const FlowEditor: React.FC<{
             <SaveControl
               agentMeta={savedAgent}
               onSave={(isTemplate) => requestSave(isTemplate ?? false)}
+              agentDescription={agentDescription}
               onDescriptionChange={setAgentDescription}
+              agentName={agentName}
               onNameChange={setAgentName}
             />
           </ControlPanel>
         </ReactFlow>
       </div>
+      <RunnerUIWrapper
+        ref={runnerUIRef}
+        nodes={nodes}
+        setNodes={setNodes}
+        isRunning={isRunning}
+        requestSaveAndRun={requestSaveAndRun}
+      />
     </FlowContext.Provider>
   );
 };
rnd/autogpt_builder/src/components/RunnerUIWrapper.tsx (new file, 141 lines)

@@ -0,0 +1,141 @@
import React, {
  useState,
  useCallback,
  forwardRef,
  useImperativeHandle,
} from "react";
import RunnerInputUI from "./runner-ui/RunnerInputUI";
import RunnerOutputUI from "./runner-ui/RunnerOutputUI";
import { Node } from "@xyflow/react";
import { filterBlocksByType } from "@/lib/utils";
import { BlockIORootSchema } from "@/lib/autogpt-server-api/types";

interface RunnerUIWrapperProps {
  nodes: Node[];
  setNodes: React.Dispatch<React.SetStateAction<Node[]>>;
  isRunning: boolean;
  requestSaveAndRun: () => void;
}

export interface RunnerUIWrapperRef {
  openRunnerInput: () => void;
  openRunnerOutput: () => void;
  runOrOpenInput: () => void;
}

const RunnerUIWrapper = forwardRef<RunnerUIWrapperRef, RunnerUIWrapperProps>(
  ({ nodes, setNodes, isRunning, requestSaveAndRun }, ref) => {
    const [isRunnerInputOpen, setIsRunnerInputOpen] = useState(false);
    const [isRunnerOutputOpen, setIsRunnerOutputOpen] = useState(false);

    const getBlockInputsAndOutputs = useCallback(() => {
      const inputBlocks = filterBlocksByType(
        nodes,
        (node) => node.data.block_id === "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
      );

      const outputBlocks = filterBlocksByType(
        nodes,
        (node) => node.data.block_id === "363ae599-353e-4804-937e-b2ee3cef3da4",
      );

      const inputs = inputBlocks.map((node) => ({
        id: node.id,
        type: "input" as const,
        inputSchema: node.data.inputSchema as BlockIORootSchema,
        hardcodedValues: {
          name: (node.data.hardcodedValues as any).name || "",
          description: (node.data.hardcodedValues as any).description || "",
          value: (node.data.hardcodedValues as any).value,
          placeholder_values:
            (node.data.hardcodedValues as any).placeholder_values || [],
          limit_to_placeholder_values:
            (node.data.hardcodedValues as any).limit_to_placeholder_values ||
            false,
        },
      }));

      const outputs = outputBlocks.map((node) => ({
        id: node.id,
        type: "output" as const,
        outputSchema: node.data.outputSchema as BlockIORootSchema,
        hardcodedValues: {
          name: (node.data.hardcodedValues as any).name || "Output",
          description:
            (node.data.hardcodedValues as any).description ||
            "Output from the agent",
          value: (node.data.hardcodedValues as any).value,
        },
        result: (node.data.executionResults as any)?.at(-1)?.data?.output,
      }));

      return { inputs, outputs };
    }, [nodes]);

    const handleInputChange = useCallback(
      (nodeId: string, field: string, value: string) => {
        setNodes((nds) =>
          nds.map((node) => {
            if (node.id === nodeId) {
              return {
                ...node,
                data: {
                  ...node.data,
                  hardcodedValues: {
                    ...(node.data.hardcodedValues as any),
                    [field]: value,
                  },
                },
              };
            }
            return node;
          }),
        );
      },
      [setNodes],
    );

    const openRunnerInput = () => setIsRunnerInputOpen(true);
    const openRunnerOutput = () => setIsRunnerOutputOpen(true);

    const runOrOpenInput = () => {
      const { inputs } = getBlockInputsAndOutputs();
      if (inputs.length > 0) {
        openRunnerInput();
      } else {
        requestSaveAndRun();
      }
    };

    useImperativeHandle(ref, () => ({
      openRunnerInput,
      openRunnerOutput,
      runOrOpenInput,
    }));

    return (
      <>
        <RunnerInputUI
          isOpen={isRunnerInputOpen}
          onClose={() => setIsRunnerInputOpen(false)}
          blockInputs={getBlockInputsAndOutputs().inputs}
          onInputChange={handleInputChange}
          onRun={() => {
            setIsRunnerInputOpen(false);
            requestSaveAndRun();
          }}
          isRunning={isRunning}
        />
        <RunnerOutputUI
          isOpen={isRunnerOutputOpen}
          onClose={() => setIsRunnerOutputOpen(false)}
          blockOutputs={getBlockInputsAndOutputs().outputs}
        />
      </>
    );
  },
);

RunnerUIWrapper.displayName = "RunnerUIWrapper";

export default RunnerUIWrapper;
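The wrapper exposes its API through `useImperativeHandle`, so the flow editor drives it via a ref rather than props. A minimal usage sketch, with names taken directly from the diffs above:

```tsx
const runnerUIRef = useRef<RunnerUIWrapperRef>(null);

// "Run" control: opens the input dialog only when the graph has input blocks,
// otherwise saves and runs immediately.
<Button onClick={() => runnerUIRef.current?.runOrOpenInput()}>Run</Button>

<RunnerUIWrapper
  ref={runnerUIRef}
  nodes={nodes}
  setNodes={setNodes}
  isRunning={isRunning}
  requestSaveAndRun={requestSaveAndRun}
/>
```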
AdminFeaturedAgentsControl (marketplace admin component)

@@ -9,6 +9,7 @@ import {
 import FeaturedAgentsTable from "./FeaturedAgentsTable";
 import { AdminAddFeaturedAgentDialog } from "./AdminAddFeaturedAgentDialog";
 import { revalidatePath } from "next/cache";
+import * as Sentry from "@sentry/nextjs";

 export default async function AdminFeaturedAgentsControl({
   className,
@@ -55,9 +56,15 @@ export default async function AdminFeaturedAgentsControl({
           component: <Button>Remove</Button>,
           action: async (rows) => {
             "use server";
-            const all = rows.map((row) => removeFeaturedAgent(row.id));
-            await Promise.all(all);
-            revalidatePath("/marketplace");
+            return await Sentry.withServerActionInstrumentation(
+              "removeFeaturedAgent",
+              {},
+              async () => {
+                const all = rows.map((row) => removeFeaturedAgent(row.id));
+                await Promise.all(all);
+                revalidatePath("/marketplace");
+              },
+            );
           },
         },
       ]}
Marketplace server actions (approve/reject/featured agents)

@@ -2,16 +2,23 @@
 import AutoGPTServerAPI from "@/lib/autogpt-server-api";
 import MarketplaceAPI from "@/lib/marketplace-api";
 import { revalidatePath } from "next/cache";
+import * as Sentry from "@sentry/nextjs";

 export async function approveAgent(
   agentId: string,
   version: number,
   comment: string,
 ) {
-  const api = new MarketplaceAPI();
-  await api.approveAgentSubmission(agentId, version, comment);
-  console.debug(`Approving agent ${agentId}`);
-  revalidatePath("/marketplace");
+  return await Sentry.withServerActionInstrumentation(
+    "approveAgent",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      await api.approveAgentSubmission(agentId, version, comment);
+      console.debug(`Approving agent ${agentId}`);
+      revalidatePath("/marketplace");
+    },
+  );
 }

 export async function rejectAgent(
@@ -19,67 +26,117 @@ export async function rejectAgent(
   version: number,
   comment: string,
 ) {
-  const api = new MarketplaceAPI();
-  await api.rejectAgentSubmission(agentId, version, comment);
-  console.debug(`Rejecting agent ${agentId}`);
-  revalidatePath("/marketplace");
+  return await Sentry.withServerActionInstrumentation(
+    "rejectAgent",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      await api.rejectAgentSubmission(agentId, version, comment);
+      console.debug(`Rejecting agent ${agentId}`);
+      revalidatePath("/marketplace");
+    },
+  );
 }

 export async function getReviewableAgents() {
-  const api = new MarketplaceAPI();
-  return api.getAgentSubmissions();
+  return await Sentry.withServerActionInstrumentation(
+    "getReviewableAgents",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      return api.getAgentSubmissions();
+    },
+  );
 }

 export async function getFeaturedAgents(
   page: number = 1,
   pageSize: number = 10,
 ) {
-  const api = new MarketplaceAPI();
-  const featured = await api.getFeaturedAgents(page, pageSize);
-  console.debug(`Getting featured agents ${featured.agents.length}`);
-  return featured;
+  return await Sentry.withServerActionInstrumentation(
+    "getFeaturedAgents",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      const featured = await api.getFeaturedAgents(page, pageSize);
+      console.debug(`Getting featured agents ${featured.agents.length}`);
+      return featured;
+    },
+  );
 }

 export async function getFeaturedAgent(agentId: string) {
-  const api = new MarketplaceAPI();
-  const featured = await api.getFeaturedAgent(agentId);
-  console.debug(`Getting featured agent ${featured.agentId}`);
-  return featured;
+  return await Sentry.withServerActionInstrumentation(
+    "getFeaturedAgent",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      const featured = await api.getFeaturedAgent(agentId);
+      console.debug(`Getting featured agent ${featured.agentId}`);
+      return featured;
+    },
+  );
 }

 export async function addFeaturedAgent(
   agentId: string,
   categories: string[] = ["featured"],
 ) {
-  const api = new MarketplaceAPI();
-  await api.addFeaturedAgent(agentId, categories);
-  console.debug(`Adding featured agent ${agentId}`);
-  revalidatePath("/marketplace");
+  return await Sentry.withServerActionInstrumentation(
+    "addFeaturedAgent",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      await api.addFeaturedAgent(agentId, categories);
+      console.debug(`Adding featured agent ${agentId}`);
+      revalidatePath("/marketplace");
+    },
+  );
 }

 export async function removeFeaturedAgent(
   agentId: string,
   categories: string[] = ["featured"],
 ) {
-  const api = new MarketplaceAPI();
-  await api.removeFeaturedAgent(agentId, categories);
-  console.debug(`Removing featured agent ${agentId}`);
-  revalidatePath("/marketplace");
+  return await Sentry.withServerActionInstrumentation(
+    "removeFeaturedAgent",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      await api.removeFeaturedAgent(agentId, categories);
+      console.debug(`Removing featured agent ${agentId}`);
+      revalidatePath("/marketplace");
+    },
+  );
 }

 export async function getCategories() {
-  const api = new MarketplaceAPI();
-  const categories = await api.getCategories();
-  console.debug(`Getting categories ${categories.unique_categories.length}`);
-  return categories;
+  return await Sentry.withServerActionInstrumentation(
+    "getCategories",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      const categories = await api.getCategories();
+      console.debug(
+        `Getting categories ${categories.unique_categories.length}`,
+      );
+      return categories;
+    },
+  );
 }

 export async function getNotFeaturedAgents(
   page: number = 1,
   pageSize: number = 100,
 ) {
-  const api = new MarketplaceAPI();
-  const agents = await api.getNotFeaturedAgents(page, pageSize);
-  console.debug(`Getting not featured agents ${agents.agents.length}`);
-  return agents;
+  return await Sentry.withServerActionInstrumentation(
+    "getNotFeaturedAgents",
+    {},
+    async () => {
+      const api = new MarketplaceAPI();
+      const agents = await api.getNotFeaturedAgents(page, pageSize);
+      console.debug(`Getting not featured agents ${agents.agents.length}`);
+      return agents;
+    },
+  );
 }
ControlPanel component

@@ -19,6 +19,7 @@ import React from "react";
 export type Control = {
   icon: React.ReactNode;
   label: string;
+  disabled?: boolean;
   onClick: () => void;
 };

@@ -50,15 +51,18 @@ export const ControlPanel = ({
         {controls.map((control, index) => (
           <Tooltip key={index} delayDuration={500}>
             <TooltipTrigger asChild>
-              <Button
-                variant="ghost"
-                size="icon"
-                onClick={() => control.onClick()}
-                data-id={`control-button-${index}`}
-              >
-                {control.icon}
-                <span className="sr-only">{control.label}</span>
-              </Button>
+              <div>
+                <Button
+                  variant="ghost"
+                  size="icon"
+                  onClick={() => control.onClick()}
+                  data-id={`control-button-${index}`}
+                  disabled={control.disabled || false}
+                >
+                  {control.icon}
+                  <span className="sr-only">{control.label}</span>
+                </Button>
+              </div>
             </TooltipTrigger>
             <TooltipContent side="right">{control.label}</TooltipContent>
           </Tooltip>
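Wrapping the now-disableable `Button` in a `<div>` looks redundant, but it is a common workaround rather than noise: a disabled button swallows pointer events, so a Radix `TooltipTrigger` attached directly to it would never open (and here the tooltip carries the "Please save the agent to run" hint). A hedged sketch of the pattern in isolation:

```tsx
// Disabled buttons don't emit pointer events, so the tooltip needs a live
// wrapper element to listen on (my gloss on a common Radix workaround).
<TooltipTrigger asChild>
  <div>
    <Button disabled>Run</Button>
  </div>
</TooltipTrigger>
<TooltipContent side="right">Please save the agent to run</TooltipContent>
```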
SaveControl component

@@ -18,6 +18,8 @@ import {
 interface SaveControlProps {
   agentMeta: GraphMeta | null;
+  agentName: string;
+  agentDescription: string;
   onSave: (isTemplate: boolean | undefined) => void;
   onNameChange: (name: string) => void;
   onDescriptionChange: (description: string) => void;
@@ -35,7 +37,9 @@ interface SaveControlProps {
 export const SaveControl = ({
   agentMeta,
   onSave,
+  agentName,
   onNameChange,
+  agentDescription,
   onDescriptionChange,
 }: SaveControlProps) => {
   /**
@@ -75,7 +79,7 @@ export const SaveControl = ({
             id="name"
             placeholder="Enter your agent name"
             className="col-span-3"
-            defaultValue={agentMeta?.name || ""}
+            value={agentName}
             onChange={(e) => onNameChange(e.target.value)}
           />
           <Label htmlFor="description">Description</Label>
@@ -83,9 +87,21 @@ export const SaveControl = ({
             id="description"
             placeholder="Your agent description"
             className="col-span-3"
-            defaultValue={agentMeta?.description || ""}
+            value={agentDescription}
             onChange={(e) => onDescriptionChange(e.target.value)}
           />
+          {agentMeta?.version && (
+            <>
+              <Label htmlFor="version">Version</Label>
+              <Input
+                id="version"
+                placeholder="Version"
+                className="col-span-3"
+                value={agentMeta?.version || "-"}
+                disabled
+              />
+            </>
+          )}
         </div>
       </CardContent>
       <CardFooter className="flex flex-col items-stretch gap-2">
Marketplace analytics server action

@@ -1,9 +1,16 @@
 "use server";

+import * as Sentry from "@sentry/nextjs";
 import MarketplaceAPI, { AnalyticsEvent } from "@/lib/marketplace-api";

 export async function makeAnalyticsEvent(event: AnalyticsEvent) {
-  const apiUrl = process.env.AGPT_SERVER_API_URL;
-  const api = new MarketplaceAPI();
-  await api.makeAnalyticsEvent(event);
+  return await Sentry.withServerActionInstrumentation(
+    "makeAnalyticsEvent",
+    {},
+    async () => {
+      const apiUrl = process.env.AGPT_SERVER_API_URL;
+      const api = new MarketplaceAPI();
+      await api.makeAnalyticsEvent(event);
+    },
+  );
 }
Node input components (NodeKeyValueInput / NodeStringInput / NodeNumberInput)

@@ -380,7 +380,7 @@ const NodeKeyValueInput: FC<{
           <Input
             type="text"
             placeholder="Value"
-            value={value ?? ""}
+            defaultValue={value ?? ""}
             onBlur={(e) =>
               updateKeyValuePairs(
                 keyValuePairs.toSpliced(index, 1, {
@@ -563,7 +563,7 @@ const NodeStringInput: FC<{
       <Input
         type="text"
         id={selfKey}
-        value={schema.secret && value ? "********" : value}
+        defaultValue={schema.secret && value ? "********" : value}
         readOnly={schema.secret}
         placeholder={
           schema?.placeholder || `Enter ${beautifyString(displayName)}`
@@ -658,7 +658,7 @@ const NodeNumberInput: FC<{
       <Input
         type="number"
         id={selfKey}
-        value={value}
+        defaultValue={value}
         onBlur={(e) => handleInputChange(selfKey, parseFloat(e.target.value))}
         placeholder={
           schema.placeholder || `Enter ${beautifyString(displayName)}`
rnd/autogpt_builder/src/components/runner-ui/RunnerInputBlock.tsx (new file, 61 lines; path derived from the imports below)

@@ -0,0 +1,61 @@
import React from "react";
import { Input } from "@/components/ui/input";
import {
  Select,
  SelectContent,
  SelectItem,
  SelectTrigger,
  SelectValue,
} from "@/components/ui/select";

interface InputBlockProps {
  id: string;
  name: string;
  description?: string;
  value: string;
  placeholder_values?: any[];
  onInputChange: (id: string, field: string, value: string) => void;
}

export function InputBlock({
  id,
  name,
  description,
  value,
  placeholder_values,
  onInputChange,
}: InputBlockProps) {
  return (
    <div className="space-y-1">
      <h3 className="text-base font-semibold">{name || "Unnamed Input"}</h3>
      {description && <p className="text-sm text-gray-600">{description}</p>}
      <div>
        {placeholder_values && placeholder_values.length > 1 ? (
          <Select
            onValueChange={(value) => onInputChange(id, "value", value)}
            value={value}
          >
            <SelectTrigger className="w-full">
              <SelectValue placeholder="Select a value" />
            </SelectTrigger>
            <SelectContent>
              {placeholder_values.map((placeholder, index) => (
                <SelectItem key={index} value={placeholder.toString()}>
                  {placeholder.toString()}
                </SelectItem>
              ))}
            </SelectContent>
          </Select>
        ) : (
          <Input
            id={`${id}-Value`}
            value={value}
            onChange={(e) => onInputChange(id, "value", e.target.value)}
            placeholder={placeholder_values?.[0]?.toString() || "Enter value"}
            className="w-full"
          />
        )}
      </div>
    </div>
  );
}
rnd/autogpt_builder/src/components/runner-ui/RunnerInputList.tsx (new file, 33 lines)

@@ -0,0 +1,33 @@
import React from "react";
import { ScrollArea } from "@/components/ui/scroll-area";
import { InputBlock } from "./RunnerInputBlock";
import { BlockInput } from "./RunnerInputUI";

interface InputListProps {
  blockInputs: BlockInput[];
  onInputChange: (nodeId: string, field: string, value: string) => void;
}

export function InputList({ blockInputs, onInputChange }: InputListProps) {
  return (
    <ScrollArea className="h-[20vh] overflow-auto pr-4 sm:h-[30vh] md:h-[40vh] lg:h-[50vh]">
      <div className="space-y-4">
        {blockInputs && blockInputs.length > 0 ? (
          blockInputs.map((block) => (
            <InputBlock
              key={block.id}
              id={block.id}
              name={block.hardcodedValues.name}
              description={block.hardcodedValues.description}
              value={block.hardcodedValues.value?.toString() || ""}
              placeholder_values={block.hardcodedValues.placeholder_values}
              onInputChange={onInputChange}
            />
          ))
        ) : (
          <p>No input blocks available.</p>
        )}
      </div>
    </ScrollArea>
  );
}
rnd/autogpt_builder/src/components/runner-ui/RunnerInputUI.tsx (new file, 74 lines)

@@ -0,0 +1,74 @@
import React from "react";
import {
  Dialog,
  DialogContent,
  DialogHeader,
  DialogTitle,
  DialogDescription,
  DialogFooter,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { BlockIORootSchema } from "@/lib/autogpt-server-api/types";
import { InputList } from "./RunnerInputList";

export interface BlockInput {
  id: string;
  inputSchema: BlockIORootSchema;
  hardcodedValues: {
    name: string;
    description: string;
    value: any;
    placeholder_values?: any[];
    limit_to_placeholder_values?: boolean;
  };
}

interface RunSettingsUiProps {
  isOpen: boolean;
  onClose: () => void;
  blockInputs: BlockInput[];
  onInputChange: (nodeId: string, field: string, value: string) => void;
  onRun: () => void;
  isRunning: boolean;
}

export function RunnerInputUI({
  isOpen,
  onClose,
  blockInputs,
  onInputChange,
  onRun,
  isRunning,
}: RunSettingsUiProps) {
  const handleRun = () => {
    onRun();
    onClose();
  };

  return (
    <Dialog open={isOpen} onOpenChange={onClose}>
      <DialogContent className="flex max-h-[80vh] flex-col overflow-hidden sm:max-w-[400px] md:max-w-[500px] lg:max-w-[600px]">
        <DialogHeader className="px-4 py-4">
          <DialogTitle className="text-2xl">Run Settings</DialogTitle>
          <DialogDescription className="mt-2 text-sm">
            Configure settings for running your agent.
          </DialogDescription>
        </DialogHeader>
        <div className="flex-grow overflow-y-auto px-4 py-4">
          <InputList blockInputs={blockInputs} onInputChange={onInputChange} />
        </div>
        <DialogFooter className="px-6 py-4">
          <Button
            onClick={handleRun}
            className="px-8 py-2 text-lg"
            disabled={isRunning}
          >
            {isRunning ? "Running..." : "Run"}
          </Button>
        </DialogFooter>
      </DialogContent>
    </Dialog>
  );
}

export default RunnerInputUI;
rnd/autogpt_builder/src/components/runner-ui/RunnerOutputUI.tsx (new file, 94 lines)

@@ -0,0 +1,94 @@
import React from "react";
import {
  Sheet,
  SheetContent,
  SheetHeader,
  SheetTitle,
  SheetDescription,
} from "@/components/ui/sheet";
import { ScrollArea } from "@/components/ui/scroll-area";
import { BlockIORootSchema } from "@/lib/autogpt-server-api/types";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";

interface BlockOutput {
  id: string;
  outputSchema: BlockIORootSchema;
  hardcodedValues: {
    name: string;
    description: string;
  };
  result?: any;
}

interface OutputModalProps {
  isOpen: boolean;
  onClose: () => void;
  blockOutputs: BlockOutput[];
}

const formatOutput = (output: any): string => {
  if (typeof output === "object") {
    try {
      return JSON.stringify(output, null, 2);
    } catch (error) {
      return `Error formatting output: ${(error as Error).message}`;
    }
  }
  return String(output);
};

export function RunnerOutputUI({
  isOpen,
  onClose,
  blockOutputs,
}: OutputModalProps) {
  return (
    <Sheet open={isOpen} onOpenChange={onClose}>
      <SheetContent
        side="right"
        className="flex h-full w-full flex-col overflow-hidden sm:max-w-[500px]"
      >
        <SheetHeader className="px-2 py-2">
          <SheetTitle className="text-xl">Run Outputs</SheetTitle>
          <SheetDescription className="mt-1 text-sm">
            View the outputs from your agent run.
          </SheetDescription>
        </SheetHeader>
        <div className="flex-grow overflow-y-auto px-2 py-2">
          <ScrollArea className="h-full overflow-auto pr-4">
            <div className="space-y-4">
              {blockOutputs && blockOutputs.length > 0 ? (
                blockOutputs.map((block) => (
                  <div key={block.id} className="space-y-1">
                    <Label className="text-base font-semibold">
                      {block.hardcodedValues.name || "Unnamed Output"}
                    </Label>

                    {block.hardcodedValues.description && (
                      <Label className="block text-sm text-gray-600">
                        {block.hardcodedValues.description}
                      </Label>
                    )}

                    <div className="rounded-md bg-gray-100 p-2">
                      <Textarea
                        readOnly
                        value={formatOutput(block.result ?? "No output yet")}
                        className="resize-none whitespace-pre-wrap break-words border-none bg-transparent text-sm"
                      />
                    </div>
                  </div>
                ))
              ) : (
                <p>No output blocks available.</p>
              )}
            </div>
          </ScrollArea>
        </div>
      </SheetContent>
    </Sheet>
  );
}

export default RunnerOutputUI;
rnd/autogpt_builder/src/components/ui/input.tsx

@@ -6,20 +6,7 @@ export interface InputProps
   extends React.InputHTMLAttributes<HTMLInputElement> {}

 const Input = React.forwardRef<HTMLInputElement, InputProps>(
-  ({ className, type, value, ...props }, ref) => {
-    // This ref allows the `Input` component to be both controlled and uncontrolled.
-    // The HTML value will only be updated if the value prop changes, but the user can still type in the input.
-    ref = ref || React.createRef<HTMLInputElement>();
-    React.useEffect(() => {
-      if (
-        ref &&
-        ref.current &&
-        ref.current.value !== value &&
-        type !== "file"
-      ) {
-        ref.current.value = value;
-      }
-    }, [value, type, ref]);
+  ({ className, type, ...props }, ref) => {
     return (
      <input
        type={type}
@@ -29,7 +16,6 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
          className,
        )}
        ref={ref}
-       defaultValue={type !== "file" ? value : undefined}
        {...props}
      />
    );
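This reverts the `Input` primitive to plain shadcn/ui behavior: callers now choose explicitly between controlled and uncontrolled usage, which is why the node inputs earlier in this diff switched to `defaultValue` + `onBlur` while `SaveControl` switched to `value` + `onChange`. A side-by-side sketch (the `commit` callback name is a stand-in, not from the PR):

```tsx
// Uncontrolled: the DOM owns the text; React reads it once on blur.
<Input defaultValue={value ?? ""} onBlur={(e) => commit(e.target.value)} />

// Controlled: React state owns the text on every keystroke.
<Input value={agentName} onChange={(e) => setAgentName(e.target.value)} />
```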
rnd/autogpt_builder/src/hooks/useAgentGraph.ts

@@ -16,6 +16,7 @@ import {
 import { Connection, MarkerType } from "@xyflow/react";
 import Ajv from "ajv";
 import { useCallback, useEffect, useMemo, useRef, useState } from "react";
+import { useRouter, useSearchParams, usePathname } from "next/navigation";

 const ajv = new Ajv({ strict: false, allErrors: true });

@@ -24,6 +25,11 @@ export default function useAgentGraph(
   template?: boolean,
   passDataToBeads?: boolean,
 ) {
+  const [router, searchParams, pathname] = [
+    useRouter(),
+    useSearchParams(),
+    usePathname(),
+  ];
   const [savedAgent, setSavedAgent] = useState<Graph | null>(null);
   const [agentDescription, setAgentDescription] = useState<string>("");
   const [agentName, setAgentName] = useState<string>("");
@@ -133,8 +139,8 @@
       id: node.id,
       type: "custom",
       position: {
-        x: node.metadata.position.x,
-        y: node.metadata.position.y,
+        x: node?.metadata?.position?.x || 0,
+        y: node?.metadata?.position?.y || 0,
       },
       data: {
         block_id: block.id,
@@ -307,7 +313,7 @@
     (template ? api.getTemplate(flowID) : api.getGraph(flowID)).then(
       (graph) => {
-        console.log("Loading graph");
+        console.debug("Loading graph");
         loadGraph(graph);
       },
     );
@@ -638,31 +644,59 @@
       links: links,
     };

-    if (savedAgent && deepEquals(payload, savedAgent)) {
-      console.debug(
-        "No need to save: Graph is the same as version on server",
-      );
-      // Trigger state change
-      setSavedAgent(savedAgent);
-      return;
-    }
-
-    setNodesSyncedWithSavedAgent(false);
-
-    const newSavedAgent = savedAgent
-      ? await (savedAgent.is_template
-          ? api.updateTemplate(savedAgent.id, payload)
-          : api.updateGraph(savedAgent.id, payload))
-      : await (asTemplate
-          ? api.createTemplate(payload)
-          : api.createGraph(payload));
-    console.debug("Response from the API:", newSavedAgent);
+    // To avoid saving the same graph, we compare the payload with the saved agent.
+    // Differences in IDs are ignored.
+    const comparedPayload = {
+      ...(({ id, ...rest }) => rest)(payload),
+      nodes: payload.nodes.map(
+        ({ id, data, input_nodes, output_nodes, ...rest }) => rest,
+      ),
+      links: payload.links.map(({ source_id, sink_id, ...rest }) => rest),
+    };
+    const comparedSavedAgent = {
+      name: savedAgent?.name,
+      description: savedAgent?.description,
+      nodes: savedAgent?.nodes.map((v) => ({
+        block_id: v.block_id,
+        input_default: v.input_default,
+        metadata: v.metadata,
+      })),
+      links: savedAgent?.links.map((v) => ({
+        sink_name: v.sink_name,
+        source_name: v.source_name,
+      })),
+    };
+
+    let newSavedAgent = null;
+    if (savedAgent && deepEquals(comparedPayload, comparedSavedAgent)) {
+      console.warn("No need to save: Graph is the same as version on server");
+      newSavedAgent = savedAgent;
+    } else {
+      console.debug(
+        "Saving new Graph version; old vs new:",
+        savedAgent,
+        comparedPayload,
+        payload,
+      );
+      setNodesSyncedWithSavedAgent(false);
+
+      newSavedAgent = savedAgent
+        ? await (savedAgent.is_template
+            ? api.updateTemplate(savedAgent.id, payload)
+            : api.updateGraph(savedAgent.id, payload))
+        : await (asTemplate
+            ? api.createTemplate(payload)
+            : api.createGraph(payload));
+
+      console.debug("Response from the API:", newSavedAgent);
+    }
+
     // Route the URL to the new flow ID if it's a new agent.
     if (!savedAgent) {
       const path = new URLSearchParams(searchParams);
       path.set("flowID", newSavedAgent.id);
       router.push(`${pathname}?${path.toString()}`);
       return;
     }

     // Update the node IDs on the frontend
     setSavedAgent(newSavedAgent);
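The `(({ id, ...rest }) => rest)(payload)` expression in the comparison above is an inline way to drop a key without mutating the source object. A small illustration of the idiom:

```ts
// Immediately-invoked arrow that destructures `id` away and returns the rest.
const payload = { id: "graph-1", name: "My agent", nodes: [] };
const withoutId = (({ id, ...rest }) => rest)(payload);
// withoutId is { name: "My agent", nodes: [] }; payload itself is untouched.
```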
rnd/autogpt_builder/src/instrumentation.ts (new file, 13 lines)

@@ -0,0 +1,13 @@
import * as Sentry from "@sentry/nextjs";

export async function register() {
  if (process.env.NEXT_RUNTIME === "nodejs") {
    await import("../sentry.server.config");
  }

  if (process.env.NEXT_RUNTIME === "edge") {
    await import("../sentry.edge.config");
  }
}

export const onRequestError = Sentry.captureRequestError;
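`register()` is Next.js's instrumentation entry point; it runs once per server runtime, which is why the Node and edge Sentry configs are imported lazily based on `NEXT_RUNTIME`. On Next.js versions where the hook is still experimental (pre-15, to my knowledge), it also has to be switched on in the config; a hedged sketch, and an assumption about this repo's Next.js version:

```ts
// next.config.mjs (sketch; only needed where the instrumentation hook
// is still behind an experimental flag)
const nextConfig = {
  experimental: {
    instrumentationHook: true,
  },
};
```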
Supabase browser client (createClient)

@@ -7,6 +7,7 @@ export function createClient() {
       process.env.NEXT_PUBLIC_SUPABASE_ANON_KEY!,
     );
   } catch (error) {
+    console.error("error creating client", error);
     return null;
   }
 }
rnd/autogpt_builder/src/lib/utils.ts

@@ -24,15 +24,16 @@ export function deepEquals(x: any, y: any): boolean {
   const ok = Object.keys,
     tx = typeof x,
     ty = typeof y;
-  return (
+
+  const res =
     x &&
     y &&
     tx === ty &&
     (tx === "object"
       ? ok(x).length === ok(y).length &&
         ok(x).every((key) => deepEquals(x[key], y[key]))
-      : x === y)
-  );
+      : x === y);
+  return res;
 }

 /** Get tailwind text color class from type name */
@@ -184,7 +185,7 @@ export const categoryColorMap: Record<string, string> = {
   SEARCH: "bg-blue-300/[.7]",
   BASIC: "bg-purple-300/[.7]",
   INPUT: "bg-cyan-300/[.7]",
-  OUTPUT: "bg-brown-300/[.7]",
+  OUTPUT: "bg-red-300/[.7]",
   LOGIC: "bg-teal-300/[.7]",
 };

@@ -194,3 +195,10 @@ export function getPrimaryCategoryColor(categories: Category[]): string {
   }
   return categoryColorMap[categories[0].category] || "bg-gray-300/[.7]";
 }
+
+export function filterBlocksByType<T>(
+  blocks: T[],
+  predicate: (block: T) => boolean,
+): T[] {
+  return blocks.filter(predicate);
+}
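`filterBlocksByType` is a thin, typed alias for `Array.prototype.filter`; its value is mostly the descriptive name at call sites. Usage as seen in `RunnerUIWrapper`, plus a quick sanity check of `deepEquals` (the block ID is the input-block constant from the diff above):

```ts
const inputBlocks = filterBlocksByType(
  nodes,
  (node) => node.data.block_id === "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
);

deepEquals({ a: [1, 2] }, { a: [1, 2] }); // true: recurses into arrays/objects
deepEquals({ a: 1 }, { a: "1" });         // false: leaves compared strictly
```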
withRoleAccess helper

@@ -1,16 +1,23 @@
 import { redirect } from "next/navigation";
 import getServerUser from "@/hooks/getServerUser";
 import React from "react";
+import * as Sentry from "@sentry/nextjs";

 export async function withRoleAccess(allowedRoles: string[]) {
   "use server";
-  return async function <T extends React.ComponentType<any>>(Component: T) {
-    const { user, role, error } = await getServerUser();
+  return await Sentry.withServerActionInstrumentation(
+    "withRoleAccess",
+    {},
+    async () => {
+      return async function <T extends React.ComponentType<any>>(Component: T) {
+        const { user, role, error } = await getServerUser();

-    if (error || !user || !role || !allowedRoles.includes(role)) {
-      redirect("/unauthorized");
-    }
+        if (error || !user || !role || !allowedRoles.includes(role)) {
+          redirect("/unauthorized");
+        }

-    return Component;
-  };
+        return Component;
+      };
+    },
+  );
 }
(File diff suppressed because it is too large.)
Credentials model (_BaseCredentials / OAuth2Credentials)

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, SecretStr, field_serializer
 class _BaseCredentials(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid4()))
     provider: str
-    title: str
+    title: Optional[str]

     @field_serializer("*")
     def dump_secret_strings(value: Any, _info):
@@ -18,6 +18,8 @@ class _BaseCredentials(BaseModel):

 class OAuth2Credentials(_BaseCredentials):
     type: Literal["oauth2"] = "oauth2"
+    username: Optional[str]
+    """Username of the third-party service user that these credentials belong to"""
     access_token: SecretStr
     access_token_expires_at: Optional[int]
     """Unix timestamp (seconds) indicating when the access token expires (if at all)"""
.env example (autogpt_server)

@@ -9,7 +9,8 @@ REDIS_HOST=localhost
 REDIS_PORT=6379
 REDIS_PASSWORD=password

-AUTH_ENABLED=false
+ENABLE_AUTH=false
+ENABLE_CREDIT=false
 APP_ENV="local"
 PYRO_HOST=localhost
 SENTRY_DSN=
Dockerfile (autogpt_server)

@@ -17,6 +17,10 @@ ENV POETRY_VERSION=1.8.3 \
     POETRY_NO_INTERACTION=1 \
     POETRY_VIRTUALENVS_CREATE=false \
     PATH="$POETRY_HOME/bin:$PATH"

+# Upgrade pip and setuptools to fix security vulnerabilities
+RUN pip3 install --upgrade pip setuptools
+
 RUN pip3 install poetry

 # Copy and install dependencies
@@ -41,6 +45,10 @@ ENV POETRY_VERSION=1.8.3 \
     POETRY_VIRTUALENVS_CREATE=false \
     PATH="$POETRY_HOME/bin:$PATH"

+# Upgrade pip and setuptools to fix security vulnerabilities
+RUN pip3 install --upgrade pip setuptools
+
 # Copy only necessary files from builder
 COPY --from=builder /app /app
 COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
ReadCsvBlock (autogpt_server blocks)

@@ -14,7 +14,8 @@ class ReadCsvBlock(Block):
         skip_columns: list[str] = []

     class Output(BlockSchema):
-        data: dict[str, str]
+        row: dict[str, str]
+        all_data: list[dict[str, str]]

     def __init__(self):
         super().__init__(
@@ -27,8 +28,15 @@ class ReadCsvBlock(Block):
                 "contents": "a, b, c\n1,2,3\n4,5,6",
             },
             test_output=[
-                ("data", {"a": "1", "b": "2", "c": "3"}),
-                ("data", {"a": "4", "b": "5", "c": "6"}),
+                ("row", {"a": "1", "b": "2", "c": "3"}),
+                ("row", {"a": "4", "b": "5", "c": "6"}),
+                (
+                    "all_data",
+                    [
+                        {"a": "1", "b": "2", "c": "3"},
+                        {"a": "4", "b": "5", "c": "6"},
+                    ],
+                ),
             ],
         )
@@ -53,8 +61,7 @@ class ReadCsvBlock(Block):
         for _ in range(input_data.skip_rows):
             next(reader)

-        # join the data with the header
-        for row in reader:
+        def process_row(row):
             data = {}
             for i, value in enumerate(row):
                 if i not in input_data.skip_columns:
@@ -62,4 +69,12 @@ class ReadCsvBlock(Block):
                     data[header[i]] = value.strip() if input_data.strip else value
                 else:
                     data[str(i)] = value.strip() if input_data.strip else value
-        yield "data", data
+            return data
+
+        all_data = []
+        for row in reader:
+            processed_row = process_row(row)
+            all_data.append(processed_row)
+            yield "row", processed_row
+
+        yield "all_data", all_data
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import List, NamedTuple
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, List, NamedTuple
|
||||
|
||||
import anthropic
|
||||
import ollama
|
||||
@@ -24,6 +25,7 @@ LlmApiKeys = {
|
||||
class ModelMetadata(NamedTuple):
|
||||
provider: str
|
||||
context_window: int
|
||||
cost_factor: int
|
||||
|
||||
|
||||
class LlmModel(str, Enum):
|
||||
@@ -55,26 +57,29 @@ class LlmModel(str, Enum):
|
||||
|
||||
|
||||
MODEL_METADATA = {
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata(
|
||||
"groq", 8192
|
||||
), # Limited to 16k during preview
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
|
||||
LlmModel.GPT4O_MINI: ModelMetadata("openai", 128000, cost_factor=10),
|
||||
LlmModel.GPT4O: ModelMetadata("openai", 128000, cost_factor=12),
|
||||
LlmModel.GPT4_TURBO: ModelMetadata("openai", 128000, cost_factor=11),
|
||||
LlmModel.GPT3_5_TURBO: ModelMetadata("openai", 16385, cost_factor=8),
|
||||
LlmModel.CLAUDE_3_5_SONNET: ModelMetadata("anthropic", 200000, cost_factor=14),
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata("anthropic", 200000, cost_factor=13),
|
||||
LlmModel.LLAMA3_8B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.LLAMA3_70B: ModelMetadata("groq", 8192, cost_factor=9),
|
||||
LlmModel.MIXTRAL_8X7B: ModelMetadata("groq", 32768, cost_factor=7),
|
||||
LlmModel.GEMMA_7B: ModelMetadata("groq", 8192, cost_factor=6),
|
||||
LlmModel.GEMMA2_9B: ModelMetadata("groq", 8192, cost_factor=7),
|
||||
LlmModel.LLAMA3_1_405B: ModelMetadata("groq", 8192, cost_factor=10),
|
||||
# Limited to 16k during preview
|
||||
LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072, cost_factor=15),
|
||||
LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072, cost_factor=13),
|
||||
LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192, cost_factor=7),
|
||||
LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192, cost_factor=11),
|
||||
}
|
||||
|
||||
for model in LlmModel:
|
||||
if model not in MODEL_METADATA:
|
||||
raise ValueError(f"Missing MODEL_METADATA metadata for model: {model}")
|
||||
|
||||
|
||||
class AIStructuredResponseGeneratorBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
@@ -89,7 +94,7 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, str]
|
||||
response: dict[str, Any]
|
||||
error: str
|
||||
|
||||
def __init__(self):
|
||||
@@ -135,16 +140,33 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
)
|
||||
return response.choices[0].message.content or ""
|
||||
elif provider == "anthropic":
|
||||
sysprompt = "".join([p["content"] for p in prompt if p["role"] == "system"])
|
||||
usrprompt = [p for p in prompt if p["role"] == "user"]
|
||||
system_messages = [p["content"] for p in prompt if p["role"] == "system"]
|
||||
sysprompt = " ".join(system_messages)
|
||||
|
||||
messages = []
|
||||
last_role = None
|
||||
for p in prompt:
|
||||
if p["role"] in ["user", "assistant"]:
|
||||
if p["role"] != last_role:
|
||||
messages.append({"role": p["role"], "content": p["content"]})
|
||||
last_role = p["role"]
|
||||
else:
|
||||
# If the role is the same as the last one, combine the content
|
||||
messages[-1]["content"] += "\n" + p["content"]
|
||||
|
||||
client = anthropic.Anthropic(api_key=api_key)
|
||||
response = client.messages.create(
|
||||
model=model.value,
|
||||
max_tokens=4096,
|
||||
system=sysprompt,
|
||||
messages=usrprompt, # type: ignore
|
||||
)
|
||||
return response.content[0].text if response.content else ""
|
||||
try:
|
||||
response = client.messages.create(
|
||||
model=model.value,
|
||||
max_tokens=4096,
|
||||
system=sysprompt,
|
||||
messages=messages,
|
||||
)
|
||||
return response.content[0].text if response.content else ""
|
||||
except anthropic.APIError as e:
|
||||
error_message = f"Anthropic API error: {str(e)}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
elif provider == "groq":
|
||||
client = Groq(api_key=api_key)
|
||||
response_format = {"type": "json_object"} if json_format else None
|
||||
@@ -195,14 +217,16 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||

        prompt.append({"role": "user", "content": input_data.prompt})

        def parse_response(resp: str) -> tuple[dict[str, str], str | None]:
        def parse_response(resp: str) -> tuple[dict[str, Any], str | None]:
            try:
                parsed = json.loads(resp)
                if not isinstance(parsed, dict):
                    return {}, f"Expected a dictionary, but got {type(parsed)}"
                miss_keys = set(input_data.expected_format.keys()) - set(parsed.keys())
                if miss_keys:
                    return parsed, f"Missing keys: {miss_keys}"
                return parsed, None
            except Exception as e:
            except JSONDecodeError as e:
                return {}, f"JSON decode error: {e}"

        logger.info(f"LLM request: {prompt}")
@@ -226,7 +250,16 @@ class AIStructuredResponseGeneratorBlock(Block):
        if input_data.expected_format:
            parsed_dict, parsed_error = parse_response(response_text)
            if not parsed_error:
                yield "response", {k: str(v) for k, v in parsed_dict.items()}
                yield "response", {
                    k: (
                        json.loads(v)
                        if isinstance(v, str)
                        and v.startswith("[")
                        and v.endswith("]")
                        else (", ".join(v) if isinstance(v, list) else v)
                    )
                    for k, v in parsed_dict.items()
                }
                return
        else:
            yield "response", {"response": response_text}
@@ -301,7 +334,7 @@ class AITextGeneratorBlock(Block):
yield "error", str(e)
|
||||
|
||||
|
||||
class TextSummarizerBlock(Block):
|
||||
class AITextSummarizerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
text: str
|
||||
model: LlmModel = LlmModel.GPT4_TURBO
|
||||
@@ -319,8 +352,8 @@ class TextSummarizerBlock(Block):
|
||||
id="c3d4e5f6-7g8h-9i0j-1k2l-m3n4o5p6q7r8",
|
||||
description="Utilize a Large Language Model (LLM) to summarize a long text.",
|
||||
categories={BlockCategory.AI, BlockCategory.TEXT},
|
||||
input_schema=TextSummarizerBlock.Input,
|
||||
output_schema=TextSummarizerBlock.Output,
|
||||
input_schema=AITextSummarizerBlock.Input,
|
||||
output_schema=AITextSummarizerBlock.Output,
|
||||
test_input={"text": "Lorem ipsum..." * 100},
|
||||
test_output=("summary", "Final summary of a long text"),
|
||||
test_mock={
|
||||
@@ -412,7 +445,7 @@ class TextSummarizerBlock(Block):
|
||||
else:
|
||||
# If combined summaries are still too long, recursively summarize
|
||||
return self._run(
|
||||
TextSummarizerBlock.Input(
|
||||
AITextSummarizerBlock.Input(
|
||||
text=combined_text,
|
||||
api_key=input_data.api_key,
|
||||
model=input_data.model,
|
||||
|
||||
264
rnd/autogpt_server/autogpt_server/blocks/sampling.py
Normal file
264
rnd/autogpt_server/autogpt_server/blocks/sampling.py
Normal file
@@ -0,0 +1,264 @@
|
||||
import random
from collections import defaultdict
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from autogpt_server.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from autogpt_server.data.model import SchemaField


class SamplingMethod(str, Enum):
    RANDOM = "random"
    SYSTEMATIC = "systematic"
    TOP = "top"
    BOTTOM = "bottom"
    STRATIFIED = "stratified"
    WEIGHTED = "weighted"
    RESERVOIR = "reservoir"
    CLUSTER = "cluster"


class DataSamplingBlock(Block):
    class Input(BlockSchema):
        data: Union[Dict[str, Any], List[Union[dict, List[Any]]]] = SchemaField(
            description="The dataset to sample from. Can be a single dictionary, a list of dictionaries, or a list of lists.",
            placeholder="{'id': 1, 'value': 'a'} or [{'id': 1, 'value': 'a'}, {'id': 2, 'value': 'b'}, ...]",
        )
        sample_size: int = SchemaField(
            description="The number of samples to take from the dataset.",
            placeholder="10",
            default=10,
        )
        sampling_method: SamplingMethod = SchemaField(
            description="The method to use for sampling.",
            default=SamplingMethod.RANDOM,
        )
        accumulate: bool = SchemaField(
            description="Whether to accumulate data before sampling.",
            default=False,
        )
        random_seed: Optional[int] = SchemaField(
            description="Seed for random number generator (optional).",
            default=None,
        )
        stratify_key: Optional[str] = SchemaField(
            description="Key to use for stratified sampling (required for stratified sampling).",
            default=None,
        )
        weight_key: Optional[str] = SchemaField(
            description="Key to use for weighted sampling (required for weighted sampling).",
            default=None,
        )
        cluster_key: Optional[str] = SchemaField(
            description="Key to use for cluster sampling (required for cluster sampling).",
            default=None,
        )

    class Output(BlockSchema):
        sampled_data: List[Union[dict, List[Any]]] = SchemaField(
            description="The sampled subset of the input data."
        )
        sample_indices: List[int] = SchemaField(
            description="The indices of the sampled data in the original dataset."
        )

    def __init__(self):
        super().__init__(
            id="4a448883-71fa-49cf-91cf-70d793bd7d87",
            description="This block samples data from a given dataset using various sampling methods.",
            categories={BlockCategory.LOGIC},
            input_schema=DataSamplingBlock.Input,
            output_schema=DataSamplingBlock.Output,
            test_input={
                "data": [
                    {"id": i, "value": chr(97 + i), "group": i % 3} for i in range(10)
                ],
                "sample_size": 3,
                "sampling_method": SamplingMethod.STRATIFIED,
                "accumulate": False,
                "random_seed": 42,
                "stratify_key": "group",
            },
            test_output=[
                (
                    "sampled_data",
                    [
                        {"id": 0, "value": "a", "group": 0},
                        {"id": 1, "value": "b", "group": 1},
                        {"id": 8, "value": "i", "group": 2},
                    ],
                ),
                ("sample_indices", [0, 1, 8]),
            ],
        )
        self.accumulated_data = []

    def run(self, input_data: Input) -> BlockOutput:
        if input_data.accumulate:
            if isinstance(input_data.data, dict):
                self.accumulated_data.append(input_data.data)
            elif isinstance(input_data.data, list):
                self.accumulated_data.extend(input_data.data)
            else:
                raise ValueError(f"Unsupported data type: {type(input_data.data)}")

            # If we don't have enough data yet, return without sampling
            if len(self.accumulated_data) < input_data.sample_size:
                return

            data_to_sample = self.accumulated_data
        else:
            # If not accumulating, use the input data directly
            data_to_sample = (
                input_data.data
                if isinstance(input_data.data, list)
                else [input_data.data]
            )

        if input_data.random_seed is not None:
            random.seed(input_data.random_seed)

        data_size = len(data_to_sample)

        if input_data.sample_size > data_size:
            raise ValueError(
                f"Sample size ({input_data.sample_size}) cannot be larger than the dataset size ({data_size})."
            )

        indices = []

        if input_data.sampling_method == SamplingMethod.RANDOM:
            indices = random.sample(range(data_size), input_data.sample_size)
        elif input_data.sampling_method == SamplingMethod.SYSTEMATIC:
            step = data_size // input_data.sample_size
            start = random.randint(0, step - 1)
            indices = list(range(start, data_size, step))[: input_data.sample_size]
        elif input_data.sampling_method == SamplingMethod.TOP:
            indices = list(range(input_data.sample_size))
        elif input_data.sampling_method == SamplingMethod.BOTTOM:
            indices = list(range(data_size - input_data.sample_size, data_size))
        elif input_data.sampling_method == SamplingMethod.STRATIFIED:
            if not input_data.stratify_key:
                raise ValueError(
                    "Stratify key must be provided for stratified sampling."
                )
            strata = defaultdict(list)
            for i, item in enumerate(data_to_sample):
                if isinstance(item, dict):
                    strata_value = item.get(input_data.stratify_key)
                elif hasattr(item, input_data.stratify_key):
                    strata_value = getattr(item, input_data.stratify_key)
                else:
                    raise ValueError(
                        f"Stratify key '{input_data.stratify_key}' not found in item {item}"
                    )

                if strata_value is None:
                    raise ValueError(
                        f"Stratify value for key '{input_data.stratify_key}' is None"
                    )

                strata[str(strata_value)].append(i)

            # Calculate the number of samples to take from each stratum
            stratum_sizes = {
                k: max(1, int(len(v) / data_size * input_data.sample_size))
                for k, v in strata.items()
            }

            # Adjust sizes to ensure we get exactly sample_size samples
            while sum(stratum_sizes.values()) != input_data.sample_size:
                if sum(stratum_sizes.values()) < input_data.sample_size:
                    stratum_sizes[
                        max(stratum_sizes, key=lambda k: stratum_sizes[k])
                    ] += 1
                else:
                    stratum_sizes[
                        max(stratum_sizes, key=lambda k: stratum_sizes[k])
                    ] -= 1

            for stratum, size in stratum_sizes.items():
                indices.extend(random.sample(strata[stratum], size))
        elif input_data.sampling_method == SamplingMethod.WEIGHTED:
            if not input_data.weight_key:
                raise ValueError("Weight key must be provided for weighted sampling.")
            weights = []
            for item in data_to_sample:
                if isinstance(item, dict):
                    weight = item.get(input_data.weight_key)
                elif hasattr(item, input_data.weight_key):
                    weight = getattr(item, input_data.weight_key)
                else:
                    raise ValueError(
                        f"Weight key '{input_data.weight_key}' not found in item {item}"
                    )

                if weight is None:
                    raise ValueError(
                        f"Weight value for key '{input_data.weight_key}' is None"
                    )
                try:
                    weights.append(float(weight))
                except ValueError:
                    raise ValueError(
                        f"Weight value '{weight}' cannot be converted to a number"
                    )

            if not weights:
                raise ValueError(
                    f"No valid weights found using key '{input_data.weight_key}'"
                )

            indices = random.choices(
                range(data_size), weights=weights, k=input_data.sample_size
            )
        elif input_data.sampling_method == SamplingMethod.RESERVOIR:
            indices = list(range(input_data.sample_size))
            for i in range(input_data.sample_size, data_size):
                j = random.randint(0, i)
                if j < input_data.sample_size:
                    indices[j] = i
        elif input_data.sampling_method == SamplingMethod.CLUSTER:
            if not input_data.cluster_key:
                raise ValueError("Cluster key must be provided for cluster sampling.")
            clusters = defaultdict(list)
            for i, item in enumerate(data_to_sample):
                if isinstance(item, dict):
                    cluster_value = item.get(input_data.cluster_key)
                elif hasattr(item, input_data.cluster_key):
                    cluster_value = getattr(item, input_data.cluster_key)
                else:
                    raise TypeError(
                        f"Item {item} does not have the cluster key '{input_data.cluster_key}'"
                    )

                clusters[str(cluster_value)].append(i)

            # Randomly select clusters until we have enough samples
            selected_clusters = []
            while (
                sum(len(clusters[c]) for c in selected_clusters)
                < input_data.sample_size
            ):
                available_clusters = [c for c in clusters if c not in selected_clusters]
                if not available_clusters:
                    break
                selected_clusters.append(random.choice(available_clusters))

            for cluster in selected_clusters:
                indices.extend(clusters[cluster])

            # If we have more samples than needed, randomly remove some
            if len(indices) > input_data.sample_size:
                indices = random.sample(indices, input_data.sample_size)
        else:
            raise ValueError(f"Unknown sampling method: {input_data.sampling_method}")

        sampled_data = [data_to_sample[i] for i in indices]

        # Clear accumulated data after sampling if accumulation is enabled
        if input_data.accumulate:
            self.accumulated_data = []

        yield "sampled_data", sampled_data
        yield "sample_indices", indices
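
A minimal usage sketch (not part of this diff) of the new block, assuming the repo's block convention that `run` is a generator yielding `(output_name, value)` pairs; the sample data is hypothetical:

block = DataSamplingBlock()
outputs = dict(
    block.run(
        DataSamplingBlock.Input(
            data=[{"id": i, "value": i * 10} for i in range(100)],
            sample_size=5,
            sampling_method=SamplingMethod.RESERVOIR,
            random_seed=7,
        )
    )
)
# Reservoir sampling gives every item an equal chance of selection,
# even though the block only sees the data as a list here.
print(outputs["sample_indices"])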

263 rnd/autogpt_server/autogpt_server/data/credit.py Normal file
@@ -0,0 +1,263 @@
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Optional, Type

import prisma.errors
from prisma import Json
from prisma.enums import UserBlockCreditType
from prisma.models import UserBlockCredit
from pydantic import BaseModel

from autogpt_server.blocks.llm import (
    MODEL_METADATA,
    AIConversationBlock,
    AIStructuredResponseGeneratorBlock,
    AITextGeneratorBlock,
    AITextSummarizerBlock,
)
from autogpt_server.blocks.talking_head import CreateTalkingAvatarVideoBlock
from autogpt_server.data.block import Block, BlockInput
from autogpt_server.util.settings import Config


class BlockCostType(str, Enum):
    RUN = "run"  # cost X credits per run
    BYTE = "byte"  # cost X credits per byte
    SECOND = "second"  # cost X credits per second


class BlockCost(BaseModel):
    cost_amount: int
    cost_filter: BlockInput
    cost_type: BlockCostType

    def __init__(
        self,
        cost_amount: int,
        cost_type: BlockCostType = BlockCostType.RUN,
        cost_filter: Optional[BlockInput] = None,
        **data: Any,
    ) -> None:
        super().__init__(
            cost_amount=cost_amount,
            cost_filter=cost_filter or {},
            cost_type=cost_type,
            **data,
        )


llm_cost = [
    BlockCost(
        cost_type=BlockCostType.RUN,
        cost_filter={
            "model": model,
            "api_key": None,  # Running LLM with user own API key is free.
        },
        cost_amount=metadata.cost_factor,
    )
    for model, metadata in MODEL_METADATA.items()
]

BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
    AIConversationBlock: llm_cost,
    AITextGeneratorBlock: llm_cost,
    AIStructuredResponseGeneratorBlock: llm_cost,
    AITextSummarizerBlock: llm_cost,
    CreateTalkingAvatarVideoBlock: [
        BlockCost(cost_amount=15, cost_filter={"api_key": None})
    ],
}

class UserCreditBase(ABC):
    def __init__(self, num_user_credits_refill: int):
        self.num_user_credits_refill = num_user_credits_refill

    @abstractmethod
    async def get_or_refill_credit(self, user_id: str) -> int:
        """
        Get the current credit for the user and refill if no transaction has been made in the current cycle.

        Returns:
            int: The current credit for the user.
        """
        pass

    @abstractmethod
    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> int:
        """
        Spend the credits for the user based on the block usage.

        Args:
            user_id (str): The user ID.
            user_credit (int): The current credit for the user.
            block (Block): The block that is being used.
            input_data (BlockInput): The input data for the block.
            data_size (float): The size of the data being processed.
            run_time (float): The time taken to run the block.

        Returns:
            int: amount of credit spent
        """
        pass

    @abstractmethod
    async def top_up_credits(self, user_id: str, amount: int):
        """
        Top up the credits for the user.

        Args:
            user_id (str): The user ID.
            amount (int): The amount to top up.
        """
        pass

class UserCredit(UserCreditBase):
    async def get_or_refill_credit(self, user_id: str) -> int:
        cur_time = self.time_now()
        cur_month = cur_time.replace(day=1, hour=0, minute=0, second=0, microsecond=0)
        nxt_month = cur_month.replace(month=cur_month.month + 1)

        user_credit = await UserBlockCredit.prisma().group_by(
            by=["userId"],
            sum={"amount": True},
            where={
                "userId": user_id,
                "createdAt": {"gte": cur_month, "lt": nxt_month},
                "isActive": True,
            },
        )

        if user_credit:
            credit_sum = user_credit[0].get("_sum") or {}
            return credit_sum.get("amount", 0)

        key = f"MONTHLY-CREDIT-TOP-UP-{cur_month}"

        try:
            await UserBlockCredit.prisma().create(
                data={
                    "amount": self.num_user_credits_refill,
                    "type": UserBlockCreditType.TOP_UP,
                    "userId": user_id,
                    "transactionKey": key,
                    "createdAt": self.time_now(),
                }
            )
        except prisma.errors.UniqueViolationError:
            pass  # Already refilled this month

        return self.num_user_credits_refill

    @staticmethod
    def time_now():
        return datetime.now(timezone.utc)

    @staticmethod
    def _block_usage_cost(
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
    ) -> tuple[int, BlockInput]:
        block_costs = BLOCK_COSTS.get(type(block))
        if not block_costs:
            return 0, {}

        for block_cost in block_costs:
            if all(input_data.get(k) == b for k, b in block_cost.cost_filter.items()):
                if block_cost.cost_type == BlockCostType.RUN:
                    return block_cost.cost_amount, block_cost.cost_filter

                if block_cost.cost_type == BlockCostType.SECOND:
                    return (
                        int(run_time * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )

                if block_cost.cost_type == BlockCostType.BYTE:
                    return (
                        int(data_size * block_cost.cost_amount),
                        block_cost.cost_filter,
                    )

        return 0, {}
    async def spend_credits(
        self,
        user_id: str,
        user_credit: int,
        block: Block,
        input_data: BlockInput,
        data_size: float,
        run_time: float,
        validate_balance: bool = True,
    ) -> int:
        cost, matching_filter = self._block_usage_cost(
            block=block, input_data=input_data, data_size=data_size, run_time=run_time
        )
        if cost <= 0:
            return 0

        if validate_balance and user_credit < cost:
            raise ValueError(f"Insufficient credit: {user_credit} < {cost}")

        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": -cost,
                "type": UserBlockCreditType.USAGE,
                "blockId": block.id,
                "metadata": Json(
                    {
                        "block": block.name,
                        "input": matching_filter,
                    }
                ),
                "createdAt": self.time_now(),
            }
        )
        return cost

    async def top_up_credits(self, user_id: str, amount: int):
        await UserBlockCredit.prisma().create(
            data={
                "userId": user_id,
                "amount": amount,
                "type": UserBlockCreditType.TOP_UP,
                "createdAt": self.time_now(),
            }
        )


class DisabledUserCredit(UserCreditBase):
    async def get_or_refill_credit(self, *args, **kwargs) -> int:
        return 0

    async def spend_credits(self, *args, **kwargs) -> int:
        return 0

    async def top_up_credits(self, *args, **kwargs):
        pass


def get_user_credit_model() -> UserCreditBase:
    config = Config()
    if config.enable_credit.lower() == "true":
        return UserCredit(config.num_user_credits_refill)
    else:
        return DisabledUserCredit(0)


def get_block_costs() -> dict[str, list[BlockCost]]:
    return {block().id: costs for block, costs in BLOCK_COSTS.items()}
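
A sketch (not part of this diff) of how the new credit API fits together, assuming a connected Prisma client and an existing user row; the helper name and zero sizes are hypothetical:

user_credit = get_user_credit_model()

async def charge_for_run(user_id: str, block: Block, input_data: BlockInput) -> int:
    balance = await user_credit.get_or_refill_credit(user_id)
    # spend_credits charges only when a BLOCK_COSTS entry's cost_filter matches
    # the input, e.g. an LLM block run without a user-supplied api_key.
    spent = await user_credit.spend_credits(
        user_id, balance, block, input_data, data_size=0, run_time=0
    )
    return balance - spent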
@@ -1,9 +1,9 @@
from collections import defaultdict
from datetime import datetime, timezone
from enum import Enum
from multiprocessing import Manager
from typing import Any, Generic, TypeVar

from prisma.enums import AgentExecutionStatus
from prisma.models import (
    AgentGraphExecution,
    AgentNodeExecution,
@@ -21,12 +21,14 @@ from autogpt_server.util import json, mock


class GraphExecution(BaseModel):
    user_id: str
    graph_exec_id: str
    start_node_execs: list["NodeExecution"]
    graph_id: str
    start_node_execs: list["NodeExecution"]


class NodeExecution(BaseModel):
    user_id: str
    graph_exec_id: str
    graph_id: str
    node_exec_id: str
@@ -34,13 +36,7 @@ class NodeExecution(BaseModel):
    data: BlockInput


class ExecutionStatus(str, Enum):
    INCOMPLETE = "INCOMPLETE"
    QUEUED = "QUEUED"
    RUNNING = "RUNNING"
    COMPLETED = "COMPLETED"
    FAILED = "FAILED"

ExecutionStatus = AgentExecutionStatus

T = TypeVar("T")

@@ -148,6 +144,7 @@ async def create_graph_execution(
        data={
            "agentGraphId": graph_id,
            "agentGraphVersion": graph_version,
            "executionStatus": ExecutionStatus.QUEUED,
            "AgentNodeExecutions": {
                "create": [  # type: ignore
                    {
@@ -259,10 +256,20 @@ async def upsert_execution_output(
    )


async def update_graph_execution_start_time(graph_exec_id: str):
    await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={
            "executionStatus": ExecutionStatus.RUNNING,
            "startedAt": datetime.now(tz=timezone.utc),
        },
    )


async def update_graph_execution_stats(graph_exec_id: str, stats: dict[str, Any]):
    await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={"stats": json.dumps(stats)},
        data={"executionStatus": ExecutionStatus.COMPLETED, "stats": json.dumps(stats)},
    )

@@ -17,8 +17,10 @@ if TYPE_CHECKING:
from autogpt_server.blocks.basic import AgentInputBlock
from autogpt_server.data import db
from autogpt_server.data.block import Block, BlockData, BlockInput, get_block
from autogpt_server.data.credit import get_user_credit_model
from autogpt_server.data.execution import (
    ExecutionQueue,
    ExecutionResult,
    ExecutionStatus,
    GraphExecution,
    NodeExecution,
@@ -45,25 +47,41 @@ from autogpt_server.util.type import convert
logger = logging.getLogger(__name__)


def get_log_metadata(
    graph_eid: str,
    graph_id: str,
    node_eid: str,
    node_id: str,
    block_name: str,
) -> dict:
    return {
        "component": "ExecutionManager",
        "graph_eid": graph_eid,
        "graph_id": graph_id,
        "node_eid": node_eid,
        "node_id": node_id,
        "block_name": block_name,
    }
class LogMetadata:
    def __init__(
        self,
        user_id: str,
        graph_eid: str,
        graph_id: str,
        node_eid: str,
        node_id: str,
        block_name: str,
    ):
        self.metadata = {
            "component": "ExecutionManager",
            "user_id": user_id,
            "graph_eid": graph_eid,
            "graph_id": graph_id,
            "node_eid": node_eid,
            "node_id": node_id,
            "block_name": block_name,
        }
        self.prefix = f"[ExecutionManager|uid:{user_id}|gid:{graph_id}|nid:{node_id}]|geid:{graph_eid}|nid:{node_eid}|{block_name}]"

    def info(self, msg: str, **extra):
        logger.info(msg, extra={"json_fields": {**self.metadata, **extra}})

def get_log_prefix(graph_eid: str, node_eid: str, block_name: str = "-"):
    return f"[ExecutionManager][graph-eid-{graph_eid}|node-eid-{node_eid}|{block_name}]"
    def warning(self, msg: str, **extra):
        logger.warning(msg, extra={"json_fields": {**self.metadata, **extra}})

    def error(self, msg: str, **extra):
        logger.error(msg, extra={"json_fields": {**self.metadata, **extra}})

    def debug(self, msg: str, **extra):
        logger.debug(msg, extra={"json_fields": {**self.metadata, **extra}})

    def exception(self, msg: str, **extra):
        logger.exception(msg, extra={"json_fields": {**self.metadata, **extra}})


T = TypeVar("T")
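
A sketch (not part of this diff) of the call pattern the new `LogMetadata` wrapper enables, replacing the old `get_log_metadata`/`get_log_prefix` pair; the IDs are hypothetical:

log_metadata = LogMetadata(
    user_id="user-1",
    graph_eid="graph-exec-1",
    graph_id="graph-1",
    node_eid="node-exec-1",
    node_id="node-1",
    block_name="AITextGeneratorBlock",
)
# Structured fields travel under extra={"json_fields": ...};
# extra keyword arguments are merged into the same payload.
log_metadata.info("Executed node with input", input="{...}")
log_metadata.error("Skip execution, input validation error: missing 'text'")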
@@ -89,6 +107,7 @@ def execute_node(
    Returns:
        The subsequent node to be enqueued, or None if there is no subsequent node.
    """
    user_id = data.user_id
    graph_exec_id = data.graph_exec_id
    graph_id = data.graph_id
    node_exec_id = data.node_exec_id
@@ -99,9 +118,10 @@ def execute_node(
    def wait(f: Coroutine[Any, Any, T]) -> T:
        return loop.run_until_complete(f)

    def update_execution(status: ExecutionStatus):
    def update_execution(status: ExecutionStatus) -> ExecutionResult:
        exec_update = wait(update_execution_status(node_exec_id, status))
        api_client.send_execution_update(exec_update.model_dump())
        return exec_update

    node = wait(get_node(node_id))

@@ -111,43 +131,35 @@ def execute_node(
        return

    # Sanity check: validate the execution input.
    log_metadata = get_log_metadata(
    log_metadata = LogMetadata(
        user_id=user_id,
        graph_eid=graph_exec_id,
        graph_id=graph_id,
        node_eid=node_exec_id,
        node_id=node_id,
        block_name=node_block.name,
    )
    prefix = get_log_prefix(
        graph_eid=graph_exec_id,
        node_eid=node_exec_id,
        block_name=node_block.name,
    )
    input_data, error = validate_exec(node, data.data, resolve_input=False)
    if input_data is None:
        logger.error(
            "{prefix} Skip execution, input validation error",
            extra={"json_fields": {**log_metadata, "error": error}},
        )
        log_metadata.error(f"Skip execution, input validation error: {error}")
        return

    # Execute the node
    input_data_str = json.dumps(input_data)
    input_size = len(input_data_str)
    logger.info(
        f"{prefix} Executed node with input",
        extra={"json_fields": {**log_metadata, "input": input_data_str}},
    )
    log_metadata.info("Executed node with input", input=input_data_str)
    update_execution(ExecutionStatus.RUNNING)
    user_credit = get_user_credit_model()

    output_size = 0
    try:
        credit = wait(user_credit.get_or_refill_credit(user_id))
        if credit < 0:
            raise ValueError(f"Insufficient credit: {credit}")

        for output_name, output_data in node_block.execute(input_data):
            output_size += len(json.dumps(output_data))
            logger.info(
                f"{prefix} Node produced output",
                extra={"json_fields": {**log_metadata, output_name: output_data}},
            )
            log_metadata.info("Node produced output", output_name=output_data)
            wait(upsert_execution_output(node_exec_id, output_name, output_data))

            for execution in _enqueue_next_nodes(
@@ -155,20 +167,25 @@ def execute_node(
                loop=loop,
                node=node,
                output=(output_name, output_data),
                user_id=user_id,
                graph_exec_id=graph_exec_id,
                graph_id=graph_id,
                log_metadata=log_metadata,
            ):
                yield execution

        update_execution(ExecutionStatus.COMPLETED)
        r = update_execution(ExecutionStatus.COMPLETED)
        s = input_size + output_size
        t = (
            (r.end_time - r.start_time).total_seconds()
            if r.end_time and r.start_time
            else 0
        )
        wait(user_credit.spend_credits(user_id, credit, node_block, input_data, s, t))

    except Exception as e:
        error_msg = f"{e.__class__.__name__}: {e}"
        logger.exception(
            f"{prefix} Node execution failed with error",
            extra={"json_fields": {**log_metadata, error: error_msg}},
        )
        error_msg = str(e)
        log_metadata.exception(f"Node execution failed with error {error_msg}")
        wait(upsert_execution_output(node_exec_id, "error", error_msg))
        update_execution(ExecutionStatus.FAILED)
@@ -194,9 +211,10 @@ def _enqueue_next_nodes(
    loop: asyncio.AbstractEventLoop,
    node: Node,
    output: BlockData,
    user_id: str,
    graph_exec_id: str,
    graph_id: str,
    log_metadata: dict,
    log_metadata: LogMetadata,
) -> list[NodeExecution]:
    def wait(f: Coroutine[Any, Any, T]) -> T:
        return loop.run_until_complete(f)
@@ -209,6 +227,7 @@ def _enqueue_next_nodes(
        )
        api_client.send_execution_update(exec_update.model_dump())
        return NodeExecution(
            user_id=user_id,
            graph_exec_id=graph_exec_id,
            graph_id=graph_id,
            node_exec_id=node_exec_id,
@@ -262,17 +281,11 @@ def _enqueue_next_nodes(

        # Incomplete input data, skip queueing the execution.
        if not next_node_input:
            logger.warning(
                f"Skipped queueing {suffix}",
                extra={"json_fields": {**log_metadata}},
            )
            log_metadata.warning(f"Skipped queueing {suffix}")
            return enqueued_executions

        # Input is complete, enqueue the execution.
        logger.info(
            f"Enqueued {suffix}",
            extra={"json_fields": {**log_metadata}},
        )
        log_metadata.info(f"Enqueued {suffix}")
        enqueued_executions.append(
            add_enqueued_execution(next_node_exec_id, next_node_id, next_node_input)
        )
@@ -298,11 +311,9 @@ def _enqueue_next_nodes(
                idata, msg = validate_exec(next_node, idata)
                suffix = f"{next_output_name}>{next_input_name}~{ineid}:{msg}"
                if not idata:
                    logger.info(
                        f"{log_metadata} Enqueueing static-link skipped: {suffix}"
                    )
                    log_metadata.info(f"Enqueueing static-link skipped: {suffix}")
                    continue
                logger.info(f"{log_metadata} Enqueueing static-link execution {suffix}")
                log_metadata.info(f"Enqueueing static-link execution {suffix}")
                enqueued_executions.append(
                    add_enqueued_execution(iexec.node_exec_id, next_node_id, idata)
                )
@@ -443,22 +454,18 @@ class Executor:
    def on_node_execution(
        cls, q: ExecutionQueue[NodeExecution], node_exec: NodeExecution
    ):
        log_metadata = get_log_metadata(
        log_metadata = LogMetadata(
            user_id=node_exec.user_id,
            graph_eid=node_exec.graph_exec_id,
            graph_id=node_exec.graph_id,
            node_eid=node_exec.node_exec_id,
            node_id=node_exec.node_id,
            block_name="-",
        )
        prefix = get_log_prefix(
            graph_eid=node_exec.graph_exec_id,
            node_eid=node_exec.node_exec_id,
            block_name="-",
        )

        execution_stats = {}
        timing_info, _ = cls._on_node_execution(
            q, node_exec, log_metadata, prefix, execution_stats
            q, node_exec, log_metadata, execution_stats
        )
        execution_stats["walltime"] = timing_info.wall_time
        execution_stats["cputime"] = timing_info.cpu_time
@@ -473,29 +480,19 @@ class Executor:
        cls,
        q: ExecutionQueue[NodeExecution],
        node_exec: NodeExecution,
        log_metadata: dict,
        prefix: str,
        log_metadata: LogMetadata,
        stats: dict[str, Any] | None = None,
    ):
        try:
            logger.info(
                f"{prefix} Start node execution {node_exec.node_exec_id}",
                extra={"json_fields": {**log_metadata}},
            )
            log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
            for execution in execute_node(
                cls.loop, cls.agent_server_client, node_exec, stats
            ):
                q.add(execution)
            logger.info(
                f"{prefix} Finished node execution {node_exec.node_exec_id}",
                extra={"json_fields": {**log_metadata}},
            )
            log_metadata.info(f"Finished node execution {node_exec.node_exec_id}")
        except Exception as e:
            logger.exception(
                f"Failed node execution {node_exec.node_exec_id}: {e}",
                extra={
                    **log_metadata,
                },
            log_metadata.exception(
                f"Failed node execution {node_exec.node_exec_id}: {e}"
            )

    @classmethod
@@ -517,10 +514,12 @@ class Executor:

    @classmethod
    def on_graph_executor_stop(cls):
        logger.info(
            f"[on_graph_executor_stop {cls.pid}] ⏳ Terminating node executor pool..."
        )
        prefix = f"[on_graph_executor_stop {cls.pid}]"
        logger.info(f"{prefix} ⏳ Disconnecting DB...")
        cls.loop.run_until_complete(db.disconnect())
        logger.info(f"{prefix} ⏳ Terminating node executor pool...")
        cls.executor.terminate()
        logger.info(f"{prefix} ✅ Finished cleanup")

    @classmethod
    def _init_node_executor_pool(cls):
@@ -532,20 +531,16 @@ class Executor:
    @classmethod
    @error_logged
    def on_graph_execution(cls, graph_exec: GraphExecution, cancel: threading.Event):
        log_metadata = get_log_metadata(
        log_metadata = LogMetadata(
            user_id=graph_exec.user_id,
            graph_eid=graph_exec.graph_exec_id,
            graph_id=graph_exec.graph_id,
            node_id="*",
            node_eid="*",
            block_name="-",
        )
        prefix = get_log_prefix(
            graph_eid=graph_exec.graph_exec_id,
            node_eid="*",
            block_name="-",
        )
        timing_info, node_count = cls._on_graph_execution(
            graph_exec, cancel, log_metadata, prefix
            graph_exec, cancel, log_metadata
        )

        cls.loop.run_until_complete(
@@ -565,13 +560,9 @@ class Executor:
        cls,
        graph_exec: GraphExecution,
        cancel: threading.Event,
        log_metadata: dict,
        prefix: str,
        log_metadata: LogMetadata,
    ) -> int:
        logger.info(
            f"{prefix} Start graph execution {graph_exec.graph_exec_id}",
            extra={"json_fields": {**log_metadata}},
        )
        log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
        n_node_executions = 0
        finished = False

@@ -581,10 +572,7 @@ class Executor:
            if finished:
                return
            cls.executor.terminate()
            logger.info(
                f"{prefix} Terminated graph execution {graph_exec.graph_exec_id}",
                extra={"json_fields": {**log_metadata}},
            )
            log_metadata.info(f"Terminated graph execution {graph_exec.graph_exec_id}")
            cls._init_node_executor_pool()

        cancel_thread = threading.Thread(target=cancel_handler)
@@ -622,10 +610,9 @@ class Executor:
            # Re-enqueueing the data back to the queue will disrupt the order.
            execution.wait()

            logger.debug(
                f"{prefix} Dispatching node execution {exec_data.node_exec_id} "
            log_metadata.debug(
                f"Dispatching node execution {exec_data.node_exec_id} "
                f"for node {exec_data.node_id}",
                extra={**log_metadata},
            )
            running_executions[exec_data.node_id] = cls.executor.apply_async(
                cls.on_node_execution,
@@ -635,10 +622,8 @@ class Executor:

            # Avoid terminating graph execution when some nodes are still running.
            while queue.empty() and running_executions:
                logger.debug(
                    "Queue empty; running nodes: "
                    f"{list(running_executions.keys())}",
                    extra={"json_fields": {**log_metadata}},
                log_metadata.debug(
                    f"Queue empty; running nodes: {list(running_executions.keys())}"
                )
                for node_id, execution in list(running_executions.items()):
                    if cancel.is_set():
@@ -647,20 +632,13 @@ class Executor:
                    if not queue.empty():
                        break  # yield to parent loop to execute new queue items

                    logger.debug(
                        f"Waiting on execution of node {node_id}",
                        extra={"json_fields": {**log_metadata}},
                    )
                    log_metadata.debug(f"Waiting on execution of node {node_id}")
                    execution.wait(3)

            logger.info(
                f"{prefix} Finished graph execution {graph_exec.graph_exec_id}",
                extra={"json_fields": {**log_metadata}},
            )
            log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
        except Exception as e:
            logger.exception(
                f"{prefix} Failed graph execution {graph_exec.graph_exec_id}: {e}",
                extra={"json_fields": {**log_metadata}},
            log_metadata.exception(
                f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
            )
        finally:
            if not cancel.is_set():
@@ -747,6 +725,7 @@ class ExecutionManager(AppService):
            for node_exec in node_execs:
                starting_node_execs.append(
                    NodeExecution(
                        user_id=user_id,
                        graph_exec_id=node_exec.graph_exec_id,
                        graph_id=node_exec.graph_id,
                        node_exec_id=node_exec.node_exec_id,
@@ -762,6 +741,7 @@ class ExecutionManager(AppService):
        self.agent_server_client.send_execution_update(exec_update.model_dump())

        graph_exec = GraphExecution(
            user_id=user_id,
            graph_id=graph_id,
            graph_exec_id=graph_exec_id,
            start_node_execs=starting_node_execs,

@@ -1,99 +0,0 @@
import time
from typing import Optional
from urllib.parse import urlencode

import requests
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials

from autogpt_server.integrations.oauth import BaseOAuthHandler


class GitHubOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at:
    - [Authorizing OAuth apps - GitHub Docs](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authorizing-oauth-apps)
    - [Refreshing user access tokens - GitHub Docs](https://docs.github.com/en/apps/creating-github-apps/authenticating-with-a-github-app/refreshing-user-access-tokens)

    Notes:
    - By default, token expiration is disabled on GitHub Apps. This means the access
      token doesn't expire and no refresh token is returned by the authorization flow.
    - When token expiration gets enabled, any existing tokens will remain non-expiring.
    - When token expiration gets disabled, token refreshes will return a non-expiring
      access token *with no refresh token*.
    """  # noqa

    PROVIDER_NAME = "github"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://github.com/login/oauth/authorize"
        self.token_url = "https://github.com/login/oauth/access_token"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": " ".join(scopes),
            "state": state,
        }
        return f"{self.auth_base_url}?{urlencode(params)}"

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
        return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if not credentials.refresh_token:
            return credentials

        return self._request_tokens(
            {
                "refresh_token": credentials.refresh_token.get_secret_value(),
                "grant_type": "refresh_token",
            }
        )

    def _request_tokens(
        self,
        params: dict[str, str],
        current_credentials: Optional[OAuth2Credentials] = None,
    ) -> OAuth2Credentials:
        request_body = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            **params,
        }
        headers = {"Accept": "application/json"}
        response = requests.post(self.token_url, data=request_body, headers=headers)
        response.raise_for_status()
        token_data: dict = response.json()

        now = int(time.time())
        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=current_credentials.title if current_credentials else "GitHub",
            access_token=token_data["access_token"],
            # Token refresh responses have an empty `scope` property (see docs),
            # so we have to get the scope from the existing credentials object.
            scopes=(
                token_data.get("scope", "").split(",")
                or (current_credentials.scopes if current_credentials else [])
            ),
            # Refresh token and expiration intervals are only given if token expiration
            # is enabled in the GitHub App's settings.
            refresh_token=token_data.get("refresh_token"),
            access_token_expires_at=(
                now + expires_in
                if (expires_in := token_data.get("expires_in", None))
                else None
            ),
            refresh_token_expires_at=(
                now + expires_in
                if (expires_in := token_data.get("refresh_token_expires_in", None))
                else None
            ),
        )
        if current_credentials:
            new_credentials.id = current_credentials.id
        return new_credentials
@@ -1,96 +0,0 @@
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr

from .oauth import BaseOAuthHandler


class GoogleOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at https://developers.google.com/identity/protocols/oauth2/web-server
    """  # noqa

    PROVIDER_NAME = "google"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.token_uri = "https://oauth2.googleapis.com/token"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        flow = self._setup_oauth_flow(scopes)
        flow.redirect_uri = self.redirect_uri
        authorization_url, _ = flow.authorization_url(
            access_type="offline",
            include_granted_scopes="true",
            state=state,
            prompt="consent",
        )
        return authorization_url

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
        flow = self._setup_oauth_flow(None)
        flow.redirect_uri = self.redirect_uri
        flow.fetch_token(code=code)

        google_creds = flow.credentials
        # Google's OAuth library is poorly typed so we need some of these:
        assert google_creds.token
        assert google_creds.refresh_token
        assert google_creds.expiry
        assert google_creds.scopes
        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title="Google",
            access_token=SecretStr(google_creds.token),
            refresh_token=SecretStr(google_creds.refresh_token),
            access_token_expires_at=int(google_creds.expiry.timestamp()),
            refresh_token_expires_at=None,
            scopes=google_creds.scopes,
        )

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # Google credentials should ALWAYS have a refresh token
        assert credentials.refresh_token

        google_creds = Credentials(
            token=credentials.access_token.get_secret_value(),
            refresh_token=credentials.refresh_token.get_secret_value(),
            token_uri=self.token_uri,
            client_id=self.client_id,
            client_secret=self.client_secret,
            scopes=credentials.scopes,
        )
        # Google's OAuth library is poorly typed so we need some of these:
        assert google_creds.refresh_token
        assert google_creds.scopes

        google_creds.refresh(Request())
        assert google_creds.expiry

        return OAuth2Credentials(
            id=credentials.id,
            provider=self.PROVIDER_NAME,
            title=credentials.title,
            access_token=SecretStr(google_creds.token),
            refresh_token=SecretStr(google_creds.refresh_token),
            access_token_expires_at=int(google_creds.expiry.timestamp()),
            refresh_token_expires_at=None,
            scopes=google_creds.scopes,
        )

    def _setup_oauth_flow(self, scopes: list[str] | None) -> Flow:
        return Flow.from_client_config(
            {
                "web": {
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
                    "token_uri": self.token_uri,
                }
            },
            scopes=scopes,
        )
@@ -1,76 +0,0 @@
from base64 import b64encode
from urllib.parse import urlencode

import requests
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials

from autogpt_server.integrations.oauth import BaseOAuthHandler


class NotionOAuthHandler(BaseOAuthHandler):
    """
    Based on the documentation at https://developers.notion.com/docs/authorization

    Notes:
    - Notion uses non-expiring access tokens and therefore doesn't have a refresh flow
    - Notion doesn't use scopes
    """

    PROVIDER_NAME = "notion"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri
        self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
        self.token_url = "https://api.notion.com/v1/oauth/token"

    def get_login_url(self, scopes: list[str], state: str) -> str:
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "response_type": "code",
            "owner": "user",
            "state": state,
        }
        return f"{self.auth_base_url}?{urlencode(params)}"

    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
        request_body = {
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        auth_str = b64encode(f"{self.client_id}:{self.client_secret}".encode()).decode()
        headers = {
            "Authorization": f"Basic {auth_str}",
            "Accept": "application/json",
        }
        response = requests.post(self.token_url, json=request_body, headers=headers)
        response.raise_for_status()
        token_data = response.json()

        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=token_data.get("workspace_name", "Notion"),
            access_token=token_data["access_token"],
            refresh_token=None,
            access_token_expires_at=None,  # Notion tokens don't expire
            refresh_token_expires_at=None,
            scopes=[],
            metadata={
                "owner": token_data["owner"],
                "bot_id": token_data["bot_id"],
                "workspace_id": token_data["workspace_id"],
                "workspace_name": token_data.get("workspace_name"),
                "workspace_icon": token_data.get("workspace_icon"),
            },
        )

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # Notion doesn't support token refresh
        return credentials

    def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
        # Notion access tokens don't expire
        return False
@@ -1,48 +0,0 @@
import time
from abc import ABC, abstractmethod
from typing import ClassVar

from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials


class BaseOAuthHandler(ABC):
    PROVIDER_NAME: ClassVar[str]

    @abstractmethod
    def __init__(self, client_id: str, client_secret: str, redirect_uri: str): ...

    @abstractmethod
    def get_login_url(self, scopes: list[str], state: str) -> str:
        """Constructs a login URL that the user can be redirected to"""
        ...

    @abstractmethod
    def exchange_code_for_tokens(self, code: str) -> OAuth2Credentials:
        """Exchanges the acquired authorization code from login for a set of tokens"""
        ...

    @abstractmethod
    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        """Implements the token refresh mechanism"""
        ...

    def refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        if credentials.provider != self.PROVIDER_NAME:
            raise ValueError(
                f"{self.__class__.__name__} can not refresh tokens "
                f"for other provider '{credentials.provider}'"
            )
        return self._refresh_tokens(credentials)

    def get_access_token(self, credentials: OAuth2Credentials) -> str:
        """Returns a valid access token, refreshing it first if needed"""
        if self.needs_refresh(credentials):
            credentials = self.refresh_tokens(credentials)
        return credentials.access_token.get_secret_value()

    def needs_refresh(self, credentials: OAuth2Credentials) -> bool:
        """Indicates whether the given tokens need to be refreshed"""
        return (
            credentials.access_token_expires_at is not None
            and credentials.access_token_expires_at < int(time.time()) + 300
        )
@@ -23,6 +23,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
    """  # noqa

    PROVIDER_NAME = "github"
    EMAIL_ENDPOINT = "https://api.github.com/user/emails"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
@@ -69,10 +70,13 @@ class GitHubOAuthHandler(BaseOAuthHandler):
        response.raise_for_status()
        token_data: dict = response.json()

        username = self._request_username(token_data["access_token"])

        now = int(time.time())
        new_credentials = OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=current_credentials.title if current_credentials else "GitHub",
            title=current_credentials.title if current_credentials else None,
            username=username,
            access_token=token_data["access_token"],
            # Token refresh responses have an empty `scope` property (see docs),
            # so we have to get the scope from the existing credentials object.
@@ -97,3 +101,19 @@ class GitHubOAuthHandler(BaseOAuthHandler):
        if current_credentials:
            new_credentials.id = current_credentials.id
        return new_credentials

    def _request_username(self, access_token: str) -> str | None:
        url = "https://api.github.com/user"
        headers = {
            "Accept": "application/vnd.github+json",
            "Authorization": f"Bearer {access_token}",
            "X-GitHub-Api-Version": "2022-11-28",
        }

        response = requests.get(url, headers=headers)

        if not response.ok:
            return None

        # Get the login (username)
        return response.json().get("login")

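A sketch (not part of this diff) of the end-to-end flow the handler supports, using the base-class helpers shown above; the client credentials, state, and received code are hypothetical:

handler = GitHubOAuthHandler(
    client_id="...",
    client_secret="...",
    redirect_uri="https://example.com/oauth/callback",
)
login_url = handler.get_login_url(scopes=["repo"], state="random-state")
# ...user authorizes, provider redirects back with ?code=...
credentials = handler.exchange_code_for_tokens(code="received-code")
token = handler.get_access_token(credentials)  # refreshes first if needed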
@@ -1,5 +1,8 @@
from autogpt_libs.supabase_integration_credentials_store import OAuth2Credentials
from google.auth.transport.requests import Request
from google.auth.external_account_authorized_user import (
    Credentials as ExternalAccountCredentials,
)
from google.auth.transport.requests import AuthorizedSession, Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import Flow
from pydantic import SecretStr
@@ -13,6 +16,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
    """  # noqa

    PROVIDER_NAME = "google"
    EMAIL_ENDPOINT = "https://www.googleapis.com/oauth2/v2/userinfo"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
@@ -37,6 +41,8 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        flow.fetch_token(code=code)

        google_creds = flow.credentials
        username = self._request_email(google_creds)

        # Google's OAuth library is poorly typed so we need some of these:
        assert google_creds.token
        assert google_creds.refresh_token
@@ -44,7 +50,8 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        assert google_creds.scopes
        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title="Google",
            title=None,
            username=username,
            access_token=SecretStr(google_creds.token),
            refresh_token=SecretStr(google_creds.refresh_token),
            access_token_expires_at=int(google_creds.expiry.timestamp()),
@@ -52,6 +59,15 @@ class GoogleOAuthHandler(BaseOAuthHandler):
            scopes=google_creds.scopes,
        )

    def _request_email(
        self, creds: Credentials | ExternalAccountCredentials
    ) -> str | None:
        session = AuthorizedSession(creds)
        response = session.get(self.EMAIL_ENDPOINT)
        if not response.ok:
            return None
        return response.json()["email"]

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        # Google credentials should ALWAYS have a refresh token
        assert credentials.refresh_token
@@ -72,9 +88,10 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        assert google_creds.expiry

        return OAuth2Credentials(
            id=credentials.id,
            provider=self.PROVIDER_NAME,
            id=credentials.id,
            title=credentials.title,
            username=credentials.username,
            access_token=SecretStr(google_creds.token),
            refresh_token=SecretStr(google_creds.refresh_token),
            access_token_expires_at=int(google_creds.expiry.timestamp()),

@@ -49,10 +49,18 @@ class NotionOAuthHandler(BaseOAuthHandler):
        response = requests.post(self.token_url, json=request_body, headers=headers)
        response.raise_for_status()
        token_data = response.json()
        # Email is only available for non-bot users
        email = (
            token_data["owner"]["person"]["email"]
            if "person" in token_data["owner"]
            and "email" in token_data["owner"]["person"]
            else None
        )

        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=token_data.get("workspace_name", "Notion"),
            title=token_data.get("workspace_name"),
            username=email,
            access_token=token_data["access_token"],
            refresh_token=None,
            access_token_expires_at=None,  # Notion tokens don't expire

@@ -4,6 +4,10 @@ from typing import Annotated, Literal
|
||||
from autogpt_libs.supabase_integration_credentials_store import (
|
||||
SupabaseIntegrationCredentialsStore,
|
||||
)
|
||||
from autogpt_libs.supabase_integration_credentials_store.types import (
|
||||
Credentials,
|
||||
OAuth2Credentials,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Query, Request
|
||||
from pydantic import BaseModel
|
||||
from supabase import Client
|
||||
@@ -48,8 +52,11 @@ async def login(
|
||||
|
||||
|
||||
class CredentialsMetaResponse(BaseModel):
|
||||
credentials_id: str
|
||||
credentials_type: Literal["oauth2", "api_key"]
|
||||
id: str
|
||||
type: Literal["oauth2", "api_key"]
|
||||
title: str | None
|
||||
scopes: list[str] | None
|
||||
username: str | None
|
||||
|
||||
|
||||
@integrations_api_router.post("/{provider}/callback")
|
||||
@@ -73,13 +80,53 @@ async def callback(
         logger.warning(f"Code->Token exchange failed for provider {provider}: {e}")
         raise HTTPException(status_code=400, detail=str(e))

     # TODO: Allow specifying `title` to set on `credentials`
     store.add_creds(user_id, credentials)
     return CredentialsMetaResponse(
-        credentials_id=credentials.id,
-        credentials_type=credentials.type,
+        id=credentials.id,
+        type=credentials.type,
+        title=credentials.title,
+        scopes=credentials.scopes,
+        username=credentials.username,
     )
+
+
+@integrations_api_router.get("/{provider}/credentials")
+async def list_credentials(
+    provider: Annotated[str, Path(title="The provider to list credentials for")],
+    user_id: Annotated[str, Depends(get_user_id)],
+    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
+) -> list[CredentialsMetaResponse]:
+    credentials = store.get_creds_by_provider(user_id, provider)
+    return [
+        CredentialsMetaResponse(
+            id=cred.id,
+            type=cred.type,
+            title=cred.title,
+            scopes=cred.scopes if isinstance(cred, OAuth2Credentials) else None,
+            username=cred.username if isinstance(cred, OAuth2Credentials) else None,
+        )
+        for cred in credentials
+    ]
+
+
+@integrations_api_router.get("/{provider}/credentials/{cred_id}")
+async def get_credential(
+    provider: Annotated[str, Path(title="The provider to retrieve credentials for")],
+    cred_id: Annotated[str, Path(title="The ID of the credentials to retrieve")],
+    user_id: Annotated[str, Depends(get_user_id)],
+    store: Annotated[SupabaseIntegrationCredentialsStore, Depends(get_store)],
+) -> Credentials:
+    credential = store.get_creds_by_id(user_id, cred_id)
+    if not credential:
+        raise HTTPException(status_code=404, detail="Credentials not found")
+    if credential.provider != provider:
+        raise HTTPException(
+            status_code=404, detail="Credentials do not match the specified provider"
+        )
+    return credential


 # -------- UTILITIES --------- #
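
Together, the callback and the two new GET routes form a small credentials API. A hedged sketch of a client session against it (base URL, route prefix, and auth header are assumptions, not part of the diff):

    import httpx

    BASE = "http://localhost:8000/integrations"  # assumed mount point
    HEADERS = {"Authorization": "Bearer <jwt>"}

    creds = httpx.get(f"{BASE}/google/credentials", headers=HEADERS).json()
    for c in creds:
        print(c["id"], c["type"], c.get("username"))

    if creds:
        one = httpx.get(
            f"{BASE}/google/credentials/{creds[0]['id']}", headers=HEADERS
        ).json()
        print(one)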
@@ -15,6 +15,7 @@ from autogpt_server.data import execution as execution_db
 from autogpt_server.data import graph as graph_db
 from autogpt_server.data import user as user_db
 from autogpt_server.data.block import BlockInput, CompletedBlockOutput
+from autogpt_server.data.credit import get_block_costs, get_user_credit_model
 from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
 from autogpt_server.data.user import get_or_create_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
@@ -32,6 +33,7 @@ class AgentServer(AppService):
     mutex = KeyedMutex()
     use_redis = True
     _test_dependency_overrides = {}
+    _user_credit_model = get_user_credit_model()

     def __init__(self, event_queue: AsyncEventQueue | None = None):
         super().__init__(port=Config().agent_server_port)
@@ -91,6 +93,11 @@ class AgentServer(AppService):
             endpoint=self.get_graph_blocks,
             methods=["GET"],
         )
+        api_router.add_api_route(
+            path="/blocks/costs",
+            endpoint=self.get_graph_block_costs,
+            methods=["GET"],
+        )
         api_router.add_api_route(
             path="/blocks/{block_id}/execute",
             endpoint=self.execute_graph_block,
@@ -196,6 +203,11 @@ class AgentServer(AppService):
             endpoint=self.update_schedule,
             methods=["PUT"],
         )
+        api_router.add_api_route(
+            path="/credits",
+            endpoint=self.get_user_credits,
+            methods=["GET"],
+        )

         api_router.add_api_route(
             path="/settings",
@@ -265,6 +277,10 @@ class AgentServer(AppService):
     def get_graph_blocks(cls) -> list[dict[Any, Any]]:
         return [v.to_dict() for v in block.get_blocks().values()]

+    @classmethod
+    def get_graph_block_costs(cls) -> dict[Any, Any]:
+        return get_block_costs()
+
     @classmethod
     def execute_graph_block(
         cls, block_id: str, data: BlockInput
@@ -481,6 +497,25 @@ class AgentServer(AppService):

         return await execution_db.list_executions(graph_id, graph_version)

+    @classmethod
+    async def get_graph_run_status(
+        cls,
+        graph_id: str,
+        graph_exec_id: str,
+        user_id: Annotated[str, Depends(get_user_id)],
+    ) -> execution_db.ExecutionStatus:
+        graph = await graph_db.get_graph(graph_id, user_id=user_id)
+        if not graph:
+            raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
+
+        execution = await execution_db.get_graph_execution(graph_exec_id, user_id)
+        if not execution:
+            raise HTTPException(
+                status_code=404, detail=f"Execution #{graph_exec_id} not found."
+            )
+
+        return execution.executionStatus
+
     @classmethod
     async def get_graph_run_node_execution_results(
         cls,
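
The new get_graph_run_status method makes it possible to poll a run without fetching every node execution. A sketch of a polling loop built on it (imports taken from elsewhere in this diff; the helper itself is hypothetical):

    import asyncio

    from autogpt_server.data import execution as execution_db
    from autogpt_server.server import AgentServer

    async def wait_for_run(
        graph_id: str, graph_exec_id: str, user_id: str, timeout: int = 60
    ) -> execution_db.ExecutionStatus:
        for _ in range(timeout):
            status = await AgentServer.get_graph_run_status(graph_id, graph_exec_id, user_id)
            if status in (
                execution_db.ExecutionStatus.COMPLETED,
                execution_db.ExecutionStatus.FAILED,
            ):
                return status
            await asyncio.sleep(1)  # check once per second
        raise TimeoutError(f"Run {graph_exec_id} did not finish within {timeout}s")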
@@ -522,6 +557,11 @@ class AgentServer(AppService):
         execution_scheduler.update_schedule(schedule_id, is_enabled, user_id=user_id)
         return {"id": schedule_id}

+    async def get_user_credits(
+        self, user_id: Annotated[str, Depends(get_user_id)]
+    ) -> dict[str, int]:
+        return {"credits": await self._user_credit_model.get_or_refill_credit(user_id)}
+
     def get_execution_schedules(
         self, graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
     ) -> dict[str, str]:
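
With the /credits route registered earlier in this file, reading a balance is a single authenticated GET. Sketch (server URL and auth header are assumptions):

    import httpx

    resp = httpx.get(
        "http://localhost:8000/credits",
        headers={"Authorization": "Bearer <jwt>"},
    )
    resp.raise_for_status()
    print(resp.json())  # e.g. {"credits": 1500}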
@@ -252,7 +252,6 @@ Here are a couple of sample of the Block class implementation:

 async def block_autogen_agent():
     async with SpinTestServer() as server:
-        test_manager = server.exec_manager
         test_user = await create_test_user()
         test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
         input_data = {"input": "Write me a block that writes a string into a file."}
@@ -261,10 +260,8 @@ async def block_autogen_agent():
         )
         print(response)
         result = await wait_execution(
-            exec_manager=test_manager,
             graph_id=test_graph.id,
             graph_exec_id=response["id"],
-            num_execs=10,
             timeout=1200,
             user_id=test_user.id,
         )
@@ -153,7 +153,6 @@ async def create_test_user() -> User:

 async def reddit_marketing_agent():
     async with SpinTestServer() as server:
-        exec_man = server.exec_manager
         test_user = await create_test_user()
         test_graph = await create_graph(create_test_graph(), user_id=test_user.id)
         input_data = {"subreddit": "AutoGPT"}
@@ -161,9 +160,7 @@ async def reddit_marketing_agent():
             test_graph.id, input_data, test_user.id
         )
         print(response)
-        result = await wait_execution(
-            exec_man, test_user.id, test_graph.id, response["id"], 13, 120
-        )
+        result = await wait_execution(test_user.id, test_graph.id, response["id"], 120)
         print(result)
@@ -75,7 +75,6 @@ def create_test_graph() -> graph.Graph:

 async def sample_agent():
     async with SpinTestServer() as server:
-        exec_man = server.exec_manager
         test_user = await create_test_user()
         test_graph = await create_graph(create_test_graph(), test_user.id)
         input_data = {"input_1": "Hello", "input_2": "World"}
@@ -83,9 +82,7 @@ async def sample_agent():
             test_graph.id, input_data, test_user.id
         )
         print(response)
-        result = await wait_execution(
-            exec_man, test_user.id, test_graph.id, response["id"], 4, 10
-        )
+        result = await wait_execution(test_user.id, test_graph.id, response["id"], 10)
         print(result)
@@ -42,15 +42,15 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
     """Config for the server."""

     num_graph_workers: int = Field(
-        default=1,
+        default=10,
         ge=1,
-        le=100,
+        le=1000,
         description="Maximum number of workers to use for graph execution.",
     )
     num_node_workers: int = Field(
-        default=1,
+        default=5,
         ge=1,
-        le=100,
+        le=1000,
         description="Maximum number of workers to use for node execution within a single graph.",
     )
     pyro_host: str = Field(
@@ -61,6 +61,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
         default="false",
         description="If authentication is enabled or not",
     )
+    enable_credit: str = Field(
+        default="false",
+        description="If user credit system is enabled or not",
+    )
+    num_user_credits_refill: int = Field(
+        default=1500,
+        description="Number of credits to refill for each user",
+    )
     # Add more configuration fields as needed

     model_config = SettingsConfigDict(
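
Because Config is a pydantic-settings model, each field above can be overridden per environment. A sketch, assuming pydantic-settings' default case-insensitive env-var mapping and an assumed module path:

    import os

    os.environ["NUM_GRAPH_WORKERS"] = "50"
    os.environ["ENABLE_CREDIT"] = "true"

    from autogpt_server.util.settings import Config  # assumed module path

    config = Config()
    assert config.num_graph_workers == 50
    assert config.enable_credit == "true"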
@@ -5,6 +5,7 @@ from autogpt_server.data import db
 from autogpt_server.data.block import Block, initialize_blocks
 from autogpt_server.data.execution import ExecutionResult, ExecutionStatus
 from autogpt_server.data.queue import AsyncEventQueue
+from autogpt_server.data.user import create_default_user
 from autogpt_server.executor import ExecutionManager, ExecutionScheduler
 from autogpt_server.server import AgentServer
 from autogpt_server.server.rest_api import get_user_id
@@ -64,6 +65,7 @@ class SpinTestServer:

         await db.connect()
         await initialize_blocks()
+        await create_default_user("false")

         return self
@@ -82,25 +84,18 @@ class SpinTestServer:


 async def wait_execution(
-    exec_manager: ExecutionManager,
     user_id: str,
     graph_id: str,
     graph_exec_id: str,
-    num_execs: int,
     timeout: int = 20,
 ) -> list:
     async def is_execution_completed():
-        execs = await AgentServer().get_graph_run_node_execution_results(
+        status = await AgentServer().get_graph_run_status(
             graph_id, graph_exec_id, user_id
         )
-        return (
-            exec_manager.queue.empty()
-            and len(execs) == num_execs
-            and all(
-                v.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]
-                for v in execs
-            )
-        )
+        if status == ExecutionStatus.FAILED:
+            raise Exception("Execution failed")
+        return status == ExecutionStatus.COMPLETED

     # Wait for the executions to complete
     for i in range(timeout):
@@ -1,4 +1,5 @@
 {
     "num_graph_workers": 10,
-    "num_node_workers": 5
+    "num_node_workers": 5,
+    "num_user_credits_refill": 1500
 }
@@ -0,0 +1,39 @@
+/*
+  Warnings:
+
+  - The `executionStatus` column on the `AgentNodeExecution` table would be dropped and recreated. This will lead to data loss if there is data in the column.
+
+*/
+-- CreateEnum
+CREATE TYPE "AgentExecutionStatus" AS ENUM ('INCOMPLETE', 'QUEUED', 'RUNNING', 'COMPLETED', 'FAILED');
+
+-- CreateEnum
+CREATE TYPE "UserBlockCreditType" AS ENUM ('TOP_UP', 'USAGE');
+
+-- AlterTable
+ALTER TABLE "AgentGraphExecution" ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED',
+ADD COLUMN "startedAt" TIMESTAMP(3);
+
+-- AlterTable
+ALTER TABLE "AgentNodeExecution" DROP COLUMN "executionStatus",
+ADD COLUMN "executionStatus" "AgentExecutionStatus" NOT NULL DEFAULT 'COMPLETED';
+
+-- CreateTable
+CREATE TABLE "UserBlockCredit" (
+    "transactionKey" TEXT NOT NULL,
+    "createdAt" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    "userId" TEXT NOT NULL,
+    "blockId" TEXT,
+    "amount" INTEGER NOT NULL,
+    "type" "UserBlockCreditType" NOT NULL,
+    "isActive" BOOLEAN NOT NULL DEFAULT true,
+    "metadata" JSONB,
+
+    CONSTRAINT "UserBlockCredit_pkey" PRIMARY KEY ("transactionKey","userId")
+);
+
+-- AddForeignKey
+ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE RESTRICT ON UPDATE CASCADE;
+
+-- AddForeignKey
+ALTER TABLE "UserBlockCredit" ADD CONSTRAINT "UserBlockCredit_blockId_fkey" FOREIGN KEY ("blockId") REFERENCES "AgentBlock"("id") ON DELETE SET NULL ON UPDATE CASCADE;
6 rnd/autogpt_server/poetry.lock generated
@@ -289,7 +289,7 @@ description = "Shared libraries across NextGen AutoGPT"
 optional = false
 python-versions = ">=3.10,<4.0"
 files = []
-develop = false
+develop = true

 [package.dependencies]
 colorama = "^0.4.6"
@@ -2022,6 +2022,7 @@ description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
     {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
 ]
@@ -2032,6 +2033,7 @@ description = "A collection of ASN.1-based protocols modules"
 optional = false
 python-versions = ">=3.8"
 files = [
+    {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"},
     {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"},
 ]
@@ -3621,4 +3623,4 @@ type = ["pytest-mypy"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "fbc928c40dc95041f7750ab34677fa3eebacd06a84944de900dedd639f847a9c"
+content-hash = "311c527a1d1947af049dac27c7a2b2f49d7fa4cdede52ef436422a528b0ad866"
@@ -13,7 +13,7 @@ python = "^3.10"
 aio-pika = "^9.4.3"
 anthropic = "^0.25.1"
 apscheduler = "^3.10.4"
-autogpt-libs = { path = "../autogpt_libs" }
+autogpt-libs = { path = "../autogpt_libs", develop = true }
 click = "^8.1.7"
 croniter = "^2.0.5"
 discord-py = "^2.4.0"
@@ -22,6 +22,7 @@ model User {
   AgentGraphs                  AgentGraph[]
   AgentGraphExecutions         AgentGraphExecution[]
   AgentGraphExecutionSchedules AgentGraphExecutionSchedule[]
+  UserBlockCredit              UserBlockCredit[]

   @@index([id])
   @@index([email])
@@ -29,9 +30,9 @@ model User {

 // This model describes the Agent Graph/Flow (Multi Agent System).
 model AgentGraph {
-  id        String   @default(uuid())
-  version   Int      @default(1)
-  createdAt DateTime @default(now())
+  id        String    @default(uuid())
+  version   Int       @default(1)
+  createdAt DateTime  @default(now())
   updatedAt DateTime? @updatedAt

   name String?
@@ -111,13 +112,26 @@ model AgentBlock {

   // Prisma requires explicit back-references.
   ReferencedByAgentNode AgentNode[]
+  UserBlockCredit       UserBlockCredit[]
 }

+// This model describes the status of an AgentGraphExecution or AgentNodeExecution.
+enum AgentExecutionStatus {
+  INCOMPLETE
+  QUEUED
+  RUNNING
+  COMPLETED
+  FAILED
+}
+
 // This model describes the execution of an AgentGraph.
 model AgentGraphExecution {
-  id        String   @id @default(uuid())
-  createdAt DateTime @default(now())
+  id              String               @id @default(uuid())
+  createdAt       DateTime             @default(now())
+  updatedAt       DateTime?            @updatedAt
+  startedAt       DateTime?

+  executionStatus AgentExecutionStatus @default(COMPLETED)

   agentGraphId      String
   agentGraphVersion Int    @default(1)
@@ -145,12 +159,10 @@ model AgentNodeExecution {
   Input  AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionInput")
   Output AgentNodeExecutionInputOutput[] @relation("AgentNodeExecutionOutput")

-  // sqlite does not support enum
-  // enum Status { INCOMPLETE, QUEUED, RUNNING, SUCCESS, FAILED }
-  executionStatus String
+  executionStatus AgentExecutionStatus @default(COMPLETED)
   // Final JSON serialized input data for the node execution.
   executionData String?
-  addedTime DateTime @default(now())
+  addedTime   DateTime  @default(now())
   queuedTime  DateTime?
   startedTime DateTime?
   endedTime   DateTime?
@@ -178,8 +190,8 @@ model AgentNodeExecutionInputOutput {

 // This model describes the recurring execution schedule of an Agent.
 model AgentGraphExecutionSchedule {
-  id        String   @id
-  createdAt DateTime @default(now())
+  id        String    @id
+  createdAt DateTime  @default(now())
   updatedAt DateTime? @updatedAt

   agentGraphId String
@@ -199,3 +211,27 @@ model AgentGraphExecutionSchedule {

   @@index([isEnabled])
 }
+
+enum UserBlockCreditType {
+  TOP_UP
+  USAGE
+}
+
+model UserBlockCredit {
+  transactionKey String   @default(uuid())
+  createdAt      DateTime @default(now())
+
+  userId String
+  user   User   @relation(fields: [userId], references: [id])
+
+  blockId String?
+  block   AgentBlock? @relation(fields: [blockId], references: [id])
+
+  amount Int
+  type   UserBlockCreditType
+
+  isActive Boolean @default(true)
+  metadata Json?
+
+  @@id(name: "creditTransactionIdentifier", [transactionKey, userId])
+}
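
A sketch of querying the new UserBlockCredit table through the generated Prisma Python client, matching the client calls used elsewhere in this diff (the helper itself is hypothetical; how amount signs encode top-ups versus usage is the credit model's concern):

    from datetime import datetime

    from prisma.models import UserBlockCredit

    async def active_amount_this_month(user_id: str, now: datetime) -> int:
        # Sum the `amount` column of active transactions since the month start.
        txs = await UserBlockCredit.prisma().find_many(
            where={
                "userId": user_id,
                "isActive": True,
                "createdAt": {"gte": datetime(now.year, now.month, 1)},
            }
        )
        return sum(tx.amount for tx in txs)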
90 rnd/autogpt_server/test/data/test_credit.py Normal file
@@ -0,0 +1,90 @@
+from datetime import datetime
+
+import pytest
+from prisma.models import UserBlockCredit
+
+from autogpt_server.blocks.llm import AITextGeneratorBlock
+from autogpt_server.data.credit import UserCredit
+from autogpt_server.data.user import DEFAULT_USER_ID
+from autogpt_server.util.test import SpinTestServer
+
+REFILL_VALUE = 1000
+user_credit = UserCredit(REFILL_VALUE)
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_block_credit_usage(server: SpinTestServer):
+    current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+
+    spending_amount_1 = await user_credit.spend_credits(
+        DEFAULT_USER_ID,
+        current_credit,
+        AITextGeneratorBlock(),
+        {"model": "gpt-4-turbo"},
+        0.0,
+        0.0,
+        validate_balance=False,
+    )
+    assert spending_amount_1 > 0
+
+    spending_amount_2 = await user_credit.spend_credits(
+        DEFAULT_USER_ID,
+        current_credit,
+        AITextGeneratorBlock(),
+        {"model": "gpt-4-turbo", "api_key": "owned_api_key"},
+        0.0,
+        0.0,
+        validate_balance=False,
+    )
+    assert spending_amount_2 == 0
+
+    new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+    assert new_credit == current_credit - spending_amount_1 - spending_amount_2
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_block_credit_top_up(server: SpinTestServer):
+    current_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+
+    await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
+
+    new_credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+    assert new_credit == current_credit + 100
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_block_credit_reset(server: SpinTestServer):
+    month1 = datetime(2022, 1, 15)
+    month2 = datetime(2022, 2, 15)
+
+    user_credit.time_now = lambda: month2
+    month2credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+
+    # Month 1 result should only affect month 1
+    user_credit.time_now = lambda: month1
+    month1credit = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+    await user_credit.top_up_credits(DEFAULT_USER_ID, 100)
+    assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month1credit + 100
+
+    # Month 2 balance is unaffected
+    user_credit.time_now = lambda: month2
+    assert await user_credit.get_or_refill_credit(DEFAULT_USER_ID) == month2credit
+
+
+@pytest.mark.asyncio(scope="session")
+async def test_credit_refill(server: SpinTestServer):
+    # Clear all transactions within the month
+    await UserBlockCredit.prisma().update_many(
+        where={
+            "userId": DEFAULT_USER_ID,
+            "createdAt": {
+                "gte": datetime(2022, 2, 1),
+                "lt": datetime(2022, 3, 1),
+            },
+        },
+        data={"isActive": False},
+    )
+    user_credit.time_now = lambda: datetime(2022, 2, 15)
+
+    balance = await user_credit.get_or_refill_credit(DEFAULT_USER_ID)
+    assert balance == REFILL_VALUE
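
The tests above pin the clock by assigning user_credit.time_now, which keeps the monthly-reset logic deterministic. A stripped-down illustration of that injection seam (hypothetical class, mirroring the pattern):

    from datetime import datetime

    class MonthlyLedger:
        time_now = staticmethod(datetime.now)  # swappable in tests

        def month_start(self) -> datetime:
            now = self.time_now()
            return datetime(now.year, now.month, 1)

    ledger = MonthlyLedger()
    ledger.time_now = lambda: datetime(2022, 2, 15)  # pin the clock
    assert ledger.month_start() == datetime(2022, 2, 1)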
@@ -4,7 +4,7 @@ import pytest

 from autogpt_server.blocks.basic import AgentInputBlock, StoreValueBlock
 from autogpt_server.data.graph import Graph, Link, Node
-from autogpt_server.data.user import DEFAULT_USER_ID, create_default_user
+from autogpt_server.data.user import DEFAULT_USER_ID
 from autogpt_server.server.model import CreateGraph
 from autogpt_server.util.test import SpinTestServer
@@ -22,8 +22,6 @@ async def test_graph_creation(server: SpinTestServer):
     Args:
         server (SpinTestServer): The test server instance.
     """
-    await create_default_user("false")
-
     value_block = StoreValueBlock().id
     input_block = AgentInputBlock().id
@@ -4,7 +4,6 @@ from prisma.models import User
 from autogpt_server.blocks.basic import FindInDictionaryBlock, StoreValueBlock
 from autogpt_server.blocks.maths import CalculatorBlock, Operation
 from autogpt_server.data import execution, graph
-from autogpt_server.executor import ExecutionManager
 from autogpt_server.server import AgentServer
 from autogpt_server.usecases.sample import create_test_graph, create_test_user
 from autogpt_server.util.test import SpinTestServer, wait_execution
@@ -12,7 +11,6 @@ from autogpt_server.util.test import SpinTestServer, wait_execution

 async def execute_graph(
     agent_server: AgentServer,
-    test_manager: ExecutionManager,
     test_graph: graph.Graph,
     test_user: User,
     input_data: dict,
@@ -23,9 +21,8 @@ async def execute_graph(
     graph_exec_id = response["id"]

     # Execution queue should be empty
-    assert await wait_execution(
-        test_manager, test_user.id, test_graph.id, graph_exec_id, num_execs
-    )
+    result = await wait_execution(test_user.id, test_graph.id, graph_exec_id)
+    assert result and len(result) == num_execs
     return graph_exec_id
@@ -108,7 +105,6 @@ async def test_agent_execution(server: SpinTestServer):
     data = {"input_1": "Hello", "input_2": "World"}
     graph_exec_id = await execute_graph(
         server.agent_server,
-        server.exec_manager,
         test_graph,
         test_user,
         data,
@@ -169,7 +165,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
     test_user = await create_test_user()
     test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
     graph_exec_id = await execute_graph(
-        server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
+        server.agent_server, test_graph, test_user, {}, 3
     )

     executions = await server.agent_server.get_graph_run_node_execution_results(
@@ -250,7 +246,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
     test_user = await create_test_user()
     test_graph = await graph.create_graph(test_graph, user_id=test_user.id)
     graph_exec_id = await execute_graph(
-        server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
+        server.agent_server, test_graph, test_user, {}, 8
     )
     executions = await server.agent_server.get_graph_run_node_execution_results(
         test_graph.id, graph_exec_id, test_user.id
@@ -17,6 +17,10 @@ ENV POETRY_VERSION=1.8.3 \
     POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false \
     PATH="$POETRY_HOME/bin:$PATH"

+# Upgrade pip and setuptools to fix security vulnerabilities
+RUN pip3 install --upgrade pip setuptools
+
 RUN pip3 install poetry

 # Copy and install dependencies
@@ -35,6 +39,9 @@ FROM python:3.11-slim-buster AS server_dependencies

 WORKDIR /app

+# Upgrade pip and setuptools to fix security vulnerabilities
+RUN pip3 install --upgrade pip setuptools
+
 # Copy only necessary files from builder
 COPY --from=builder /app /app
 COPY --from=builder /usr/local/lib/python3.11 /usr/local/lib/python3.11
@@ -353,7 +353,9 @@ async def search_db(


 async def get_top_agents_by_downloads(
-    page: int = 1, page_size: int = 10
+    page: int = 1,
+    page_size: int = 10,
+    submission_status: prisma.enums.SubmissionStatus = prisma.enums.SubmissionStatus.APPROVED,
 ) -> TopAgentsDBResponse:
     """Retrieve the top agents by download count.

@@ -374,6 +376,7 @@ async def get_top_agents_by_downloads(
     analytics = await prisma.models.AnalyticsTracker.prisma().find_many(
         include={"agent": True},
         order={"downloads": "desc"},
+        where={"agent": {"is": {"submissionStatus": submission_status}}},
         skip=skip,
         take=page_size,
     )
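
The added where clause filters through a to-one relation using Prisma's "is" operator. The same query as a standalone sketch (model and enum names come from this repo's schema; the wrapper function is hypothetical):

    import prisma.enums
    import prisma.models

    async def top_approved(skip: int = 0, take: int = 10):
        # Only rows whose related agent has been approved are returned.
        return await prisma.models.AnalyticsTracker.prisma().find_many(
            include={"agent": True},
            order={"downloads": "desc"},
            where={
                "agent": {
                    "is": {"submissionStatus": prisma.enums.SubmissionStatus.APPROVED}
                }
            },
            skip=skip,
            take=take,
        )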
@@ -441,7 +444,10 @@ async def set_agent_featured(


 async def get_featured_agents(
-    category: str = "featured", page: int = 1, page_size: int = 10
+    category: str = "featured",
+    page: int = 1,
+    page_size: int = 10,
+    submission_status: prisma.enums.SubmissionStatus = prisma.enums.SubmissionStatus.APPROVED,
 ) -> FeaturedAgentResponse:
     """Retrieve a list of featured agents from the database based on the provided category.

@@ -463,6 +469,7 @@ async def get_featured_agents(
         where={
             "featuredCategories": {"has": category},
             "isActive": True,
+            "agent": {"is": {"submissionStatus": submission_status}},
         },
         include={"agent": {"include": {"AnalyticsTracker": True}}},
         skip=skip,
@@ -5,6 +5,7 @@ import typing
 import fastapi
 import fastapi.responses
 import prisma
+import prisma.enums

 import market.db
 import market.model
@@ -38,6 +39,10 @@ async def list_agents(
     sort_order: typing.Literal["asc", "desc"] = fastapi.Query(
         "desc", description="Sort order (asc or desc)"
     ),
+    submission_status: prisma.enums.SubmissionStatus = fastapi.Query(
+        default=prisma.enums.SubmissionStatus.APPROVED,
+        description="Filter by submission status",
+    ),
 ):
     """
     Retrieve a list of agents based on the provided filters.
@@ -52,6 +57,7 @@ async def list_agents(
         description_threshold (int): Fuzzy search threshold (default: 60, min: 0, max: 100).
         sort_by (str): Field to sort by (default: "createdAt").
         sort_order (str): Sort order (asc or desc) (default: "desc").
+        submission_status (str): Filter by submission status (default: "APPROVED").

     Returns:
         market.model.AgentListResponse: A response containing the list of agents and pagination information.
@@ -70,6 +76,7 @@ async def list_agents(
             description_threshold=description_threshold,
             sort_by=sort_by,
             sort_order=sort_order,
+            submission_status=submission_status,
         )

         agents = [
@@ -210,6 +217,10 @@ async def top_agents_by_downloads(
     page_size: int = fastapi.Query(
         10, ge=1, le=100, description="Number of items per page"
     ),
+    submission_status: prisma.enums.SubmissionStatus = fastapi.Query(
+        default=prisma.enums.SubmissionStatus.APPROVED,
+        description="Filter by submission status",
+    ),
 ):
     """
     Retrieve a list of top agents based on the number of downloads.
@@ -217,6 +228,7 @@ async def top_agents_by_downloads(
     Args:
         page (int): Page number (default: 1).
         page_size (int): Number of items per page (default: 10, min: 1, max: 100).
+        submission_status (str): Filter by submission status (default: "APPROVED").

     Returns:
         market.model.AgentListResponse: A response containing the list of top agents and pagination information.
@@ -228,6 +240,7 @@ async def top_agents_by_downloads(
         result = await market.db.get_top_agents_by_downloads(
             page=page,
             page_size=page_size,
+            submission_status=submission_status,
         )

         ret = market.model.AgentListResponse(
@@ -274,6 +287,10 @@ async def get_featured_agents(
     page_size: int = fastapi.Query(
         10, ge=1, le=100, description="Number of items per page"
     ),
+    submission_status: prisma.enums.SubmissionStatus = fastapi.Query(
+        default=prisma.enums.SubmissionStatus.APPROVED,
+        description="Filter by submission status",
+    ),
 ):
     """
     Retrieve a list of featured agents based on the provided category.
@@ -282,6 +299,7 @@ async def get_featured_agents(
         category (str): Category of featured agents (default: "featured").
         page (int): Page number (default: 1).
         page_size (int): Number of items per page (default: 10, min: 1, max: 100).
+        submission_status (str): Filter by submission status (default: "APPROVED").

     Returns:
         market.model.AgentListResponse: A response containing the list of featured agents and pagination information.
@@ -294,6 +312,7 @@ async def get_featured_agents(
             category=category,
             page=page,
             page_size=page_size,
+            submission_status=submission_status,
         )

         ret = market.model.AgentListResponse(
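
A quick way to exercise the new submission_status query parameter end to end is FastAPI's TestClient; in this sketch the app import path and the route prefix are assumptions:

    from fastapi.testclient import TestClient

    from market.app import app  # assumed entry point

    client = TestClient(app)
    resp = client.get("/agents", params={"submission_status": "APPROVED"})
    assert resp.status_code == 200
    print(resp.json())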