mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 23:28:07 -05:00
Merge branch 'master' into aarushikansal/execution-manager
This commit is contained in:
@@ -36,5 +36,3 @@ rnd/autogpt_builder/.env.example
|
||||
rnd/autogpt_builder/.env.local
|
||||
rnd/autogpt_server/.env
|
||||
rnd/autogpt_server/.venv/
|
||||
|
||||
|
||||
|
||||
9
.github/workflows/autogpt-server-ci.yml
vendored
9
.github/workflows/autogpt-server-ci.yml
vendored
@@ -128,9 +128,14 @@ jobs:
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
poetry run pytest -vv \
|
||||
test
|
||||
if [[ "${{ runner.debug }}" == "1" ]]; then
|
||||
poetry run pytest -vv -o log_cli=true -o log_cli_level=DEBUG test
|
||||
else
|
||||
poetry run pytest -vv test
|
||||
fi
|
||||
if: success() || (failure() && steps.lint.outcome == 'failure')
|
||||
env:
|
||||
LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
|
||||
env:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
|
||||
@@ -29,6 +29,7 @@
|
||||
"@radix-ui/react-tooltip": "^1.1.2",
|
||||
"@supabase/ssr": "^0.4.0",
|
||||
"@supabase/supabase-js": "^2.45.0",
|
||||
"@tanstack/react-table": "^8.20.5",
|
||||
"@xyflow/react": "^12.1.0",
|
||||
"ajv": "^8.17.1",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
@@ -47,8 +48,8 @@
|
||||
"react-icons": "^5.2.1",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-modal": "^3.16.1",
|
||||
"recharts": "^2.12.7",
|
||||
"react-shepherd": "^6.1.1",
|
||||
"recharts": "^2.12.7",
|
||||
"tailwind-merge": "^2.3.0",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"uuid": "^10.0.0",
|
||||
|
||||
@@ -2,15 +2,18 @@ import { withRoleAccess } from "@/lib/withRoleAccess";
|
||||
|
||||
import React from "react";
|
||||
import { getReviewableAgents } from "@/components/admin/marketplace/actions";
|
||||
import AdminMarketplaceCard from "@/components/admin/marketplace/AdminMarketplaceCard";
|
||||
import AdminMarketplaceAgentList from "@/components/admin/marketplace/AdminMarketplaceAgentList";
|
||||
import AdminFeaturedAgentsControl from "@/components/admin/marketplace/AdminFeaturedAgentsControl";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
async function AdminMarketplace() {
|
||||
const agents = await getReviewableAgents();
|
||||
const reviewableAgents = await getReviewableAgents();
|
||||
|
||||
return (
|
||||
<div>
|
||||
<h3>Agents to review</h3>
|
||||
<AdminMarketplaceAgentList agents={agents.agents} />
|
||||
</div>
|
||||
<>
|
||||
<AdminMarketplaceAgentList agents={reviewableAgents.agents} />
|
||||
<Separator className="my-4" />
|
||||
<AdminFeaturedAgentsControl className="mt-4" />
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import { Suspense } from "react";
|
||||
import { notFound } from "next/navigation";
|
||||
import MarketplaceAPI from "@/lib/marketplace-api";
|
||||
import { AgentDetailResponse } from "@/lib/marketplace-api";
|
||||
import AgentDetailContent from "@/components/AgentDetailContent";
|
||||
import AgentDetailContent from "@/components/marketplace/AgentDetailContent";
|
||||
|
||||
async function getAgentDetails(id: string): Promise<AgentDetailResponse> {
|
||||
const apiUrl =
|
||||
|
||||
@@ -90,10 +90,11 @@ const Monitor = () => {
|
||||
<FlowRunsList
|
||||
className={column2}
|
||||
flows={flows}
|
||||
runs={(selectedFlow
|
||||
? flowRuns.filter((v) => v.graphID == selectedFlow.id)
|
||||
: flowRuns
|
||||
).toSorted((a, b) => Number(a.startTime) - Number(b.startTime))}
|
||||
runs={[
|
||||
...(selectedFlow
|
||||
? flowRuns.filter((v) => v.graphID == selectedFlow.id)
|
||||
: flowRuns),
|
||||
].sort((a, b) => Number(a.startTime) - Number(b.startTime))}
|
||||
selectedRun={selectedRun}
|
||||
onSelectRun={(r) => setSelectedRun(r.id == selectedRun?.id ? null : r)}
|
||||
/>
|
||||
|
||||
@@ -34,10 +34,16 @@ import ConnectionLine from "./ConnectionLine";
|
||||
import { Control, ControlPanel } from "@/components/edit/control/ControlPanel";
|
||||
import { SaveControl } from "@/components/edit/control/SaveControl";
|
||||
import { BlocksControl } from "@/components/edit/control/BlocksControl";
|
||||
import { IconPlay, IconRedo2, IconUndo2 } from "@/components/ui/icons";
|
||||
import {
|
||||
IconPlay,
|
||||
IconRedo2,
|
||||
IconSquare,
|
||||
IconUndo2,
|
||||
} from "@/components/ui/icons";
|
||||
import { startTutorial } from "./tutorial";
|
||||
import useAgentGraph from "@/hooks/useAgentGraph";
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { useRouter, usePathname, useSearchParams } from "next/navigation";
|
||||
|
||||
// This is for the history, this is the minimum distance a block must move before it is logged
|
||||
// It helps to prevent spamming the history with small movements especially when pressing on a input in a block
|
||||
@@ -74,13 +80,17 @@ const FlowEditor: React.FC<{
|
||||
availableNodes,
|
||||
getOutputType,
|
||||
requestSave,
|
||||
requestSaveRun,
|
||||
requestSaveAndRun,
|
||||
requestStopRun,
|
||||
isRunning,
|
||||
nodes,
|
||||
setNodes,
|
||||
edges,
|
||||
setEdges,
|
||||
} = useAgentGraph(flowID, template, visualizeBeads !== "no");
|
||||
|
||||
const router = useRouter();
|
||||
const pathname = usePathname();
|
||||
const initialPositionRef = useRef<{
|
||||
[key: string]: { x: number; y: number };
|
||||
}>({});
|
||||
@@ -97,7 +107,7 @@ const FlowEditor: React.FC<{
|
||||
// If resetting tutorial
|
||||
if (params.get("resetTutorial") === "true") {
|
||||
localStorage.removeItem("shepherd-tour"); // Clear tutorial flag
|
||||
window.location.href = window.location.pathname; // Redirect to clear URL parameters
|
||||
router.push(pathname);
|
||||
} else {
|
||||
// Otherwise, start tutorial if conditions are met
|
||||
const shouldStartTutorial = !localStorage.getItem("shepherd-tour");
|
||||
@@ -539,9 +549,9 @@ const FlowEditor: React.FC<{
|
||||
onClick: handleRedo,
|
||||
},
|
||||
{
|
||||
label: "Run",
|
||||
icon: <IconPlay />,
|
||||
onClick: requestSaveRun,
|
||||
label: !isRunning ? "Run" : "Stop",
|
||||
icon: !isRunning ? <IconPlay /> : <IconSquare />,
|
||||
onClick: !isRunning ? requestSaveAndRun : requestStopRun,
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
@@ -0,0 +1,149 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogClose,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
MultiSelector,
|
||||
MultiSelectorContent,
|
||||
MultiSelectorInput,
|
||||
MultiSelectorItem,
|
||||
MultiSelectorList,
|
||||
MultiSelectorTrigger,
|
||||
} from "@/components/ui/multiselect";
|
||||
import { Controller, useForm } from "react-hook-form";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useState } from "react";
|
||||
import { addFeaturedAgent } from "./actions";
|
||||
import { Agent } from "@/lib/marketplace-api/types";
|
||||
|
||||
type FormData = {
|
||||
agent: string;
|
||||
categories: string[];
|
||||
};
|
||||
|
||||
export const AdminAddFeaturedAgentDialog = ({
|
||||
categories,
|
||||
agents,
|
||||
}: {
|
||||
categories: string[];
|
||||
agents: Agent[];
|
||||
}) => {
|
||||
const [selectedAgent, setSelectedAgent] = useState<string>("");
|
||||
const [selectedCategories, setSelectedCategories] = useState<string[]>([]);
|
||||
|
||||
const {
|
||||
control,
|
||||
handleSubmit,
|
||||
watch,
|
||||
setValue,
|
||||
formState: { errors },
|
||||
} = useForm<FormData>({
|
||||
defaultValues: {
|
||||
agent: "",
|
||||
categories: [],
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<Button variant="outline" size="sm">
|
||||
Add Featured Agent
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add Featured Agent</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Controller
|
||||
name="agent"
|
||||
control={control}
|
||||
rules={{ required: true }}
|
||||
render={({ field }) => (
|
||||
<div>
|
||||
<label htmlFor={field.name}>Agent</label>
|
||||
<Select
|
||||
onValueChange={(value) => {
|
||||
field.onChange(value);
|
||||
setSelectedAgent(value);
|
||||
}}
|
||||
value={field.value || ""}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select an agent" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{/* Populate with agents */}
|
||||
{agents.map((agent) => (
|
||||
<SelectItem key={agent.id} value={agent.id}>
|
||||
{agent.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
<Controller
|
||||
name="categories"
|
||||
control={control}
|
||||
render={({ field }) => (
|
||||
<MultiSelector
|
||||
values={field.value || []}
|
||||
onValuesChange={(values) => {
|
||||
field.onChange(values);
|
||||
setSelectedCategories(values);
|
||||
}}
|
||||
>
|
||||
<MultiSelectorTrigger>
|
||||
<MultiSelectorInput placeholder="Select categories" />
|
||||
</MultiSelectorTrigger>
|
||||
<MultiSelectorContent>
|
||||
<MultiSelectorList>
|
||||
{categories.map((category) => (
|
||||
<MultiSelectorItem key={category} value={category}>
|
||||
{category}
|
||||
</MultiSelectorItem>
|
||||
))}
|
||||
</MultiSelectorList>
|
||||
</MultiSelectorContent>
|
||||
</MultiSelector>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
<DialogFooter>
|
||||
<DialogClose asChild>
|
||||
<Button variant="outline">Cancel</Button>
|
||||
</DialogClose>
|
||||
<DialogClose asChild>
|
||||
<Button
|
||||
type="submit"
|
||||
onClick={async () => {
|
||||
// Handle adding the featured agent
|
||||
await addFeaturedAgent(selectedAgent, selectedCategories);
|
||||
// close the dialog
|
||||
}}
|
||||
>
|
||||
Add
|
||||
</Button>
|
||||
</DialogClose>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
@@ -0,0 +1,67 @@
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
getFeaturedAgents,
|
||||
removeFeaturedAgent,
|
||||
getCategories,
|
||||
getNotFeaturedAgents,
|
||||
} from "./actions";
|
||||
|
||||
import FeaturedAgentsTable from "./FeaturedAgentsTable";
|
||||
import { AdminAddFeaturedAgentDialog } from "./AdminAddFeaturedAgentDialog";
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
export default async function AdminFeaturedAgentsControl({
|
||||
className,
|
||||
}: {
|
||||
className?: string;
|
||||
}) {
|
||||
// add featured agent button
|
||||
// modal to select agent?
|
||||
// modal to select categories?
|
||||
// table of featured agents
|
||||
// in table
|
||||
// remove featured agent button
|
||||
// edit featured agent categories button
|
||||
// table footer
|
||||
// Next page button
|
||||
// Previous page button
|
||||
// Page number input
|
||||
// Page size input
|
||||
// Total pages input
|
||||
// Go to page button
|
||||
|
||||
const page = 1;
|
||||
const pageSize = 10;
|
||||
|
||||
const agents = await getFeaturedAgents(page, pageSize);
|
||||
|
||||
const categories = await getCategories();
|
||||
|
||||
const notFeaturedAgents = await getNotFeaturedAgents();
|
||||
|
||||
return (
|
||||
<div className={`flex flex-col gap-4 ${className}`}>
|
||||
<div className="mb-4 flex justify-between">
|
||||
<h3 className="text-lg font-semibold">Featured Agent Controls</h3>
|
||||
<AdminAddFeaturedAgentDialog
|
||||
categories={categories.unique_categories}
|
||||
agents={notFeaturedAgents.agents}
|
||||
/>
|
||||
</div>
|
||||
<FeaturedAgentsTable
|
||||
agents={agents.agents}
|
||||
globalActions={[
|
||||
{
|
||||
component: <Button>Remove</Button>,
|
||||
action: async (rows) => {
|
||||
"use server";
|
||||
const all = rows.map((row) => removeFeaturedAgent(row.id));
|
||||
await Promise.all(all);
|
||||
revalidatePath("/marketplace");
|
||||
},
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -4,23 +4,33 @@ import { ClipboardX } from "lucide-react";
|
||||
|
||||
export default function AdminMarketplaceAgentList({
|
||||
agents,
|
||||
className,
|
||||
}: {
|
||||
agents: Agent[];
|
||||
className?: string;
|
||||
}) {
|
||||
if (agents.length === 0) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center py-12 text-gray-500">
|
||||
<ClipboardX size={48} />
|
||||
<p className="mt-4 text-lg font-semibold">No agents to review</p>
|
||||
<div className={className}>
|
||||
<h3 className="text-lg font-semibold">Agents to review</h3>
|
||||
<div className="flex flex-col items-center justify-center py-12 text-gray-500">
|
||||
<ClipboardX size={48} />
|
||||
<p className="mt-4 text-lg font-semibold">No agents to review</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
{agents.map((agent) => (
|
||||
<AdminMarketplaceCard agent={agent} key={agent.id} />
|
||||
))}
|
||||
<div className={`flex flex-col gap-4 ${className}`}>
|
||||
<div>
|
||||
<h3 className="text-lg font-semibold">Agents to review</h3>
|
||||
</div>
|
||||
<div className="flex flex-col gap-4">
|
||||
{agents.map((agent) => (
|
||||
<AdminMarketplaceCard agent={agent} key={agent.id} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Checkbox } from "@/components/ui/checkbox";
|
||||
import { DataTable } from "@/components/ui/data-table";
|
||||
import { Agent } from "@/lib/marketplace-api";
|
||||
import { ColumnDef } from "@tanstack/react-table";
|
||||
import { ArrowUpDown } from "lucide-react";
|
||||
import { removeFeaturedAgent } from "./actions";
|
||||
import { GlobalActions } from "@/components/ui/data-table";
|
||||
|
||||
export const columns: ColumnDef<Agent>[] = [
|
||||
{
|
||||
id: "select",
|
||||
header: ({ table }) => (
|
||||
<Checkbox
|
||||
checked={
|
||||
table.getIsAllPageRowsSelected() ||
|
||||
(table.getIsSomePageRowsSelected() && "indeterminate")
|
||||
}
|
||||
onCheckedChange={(value) => table.toggleAllPageRowsSelected(!!value)}
|
||||
aria-label="Select all"
|
||||
/>
|
||||
),
|
||||
cell: ({ row }) => (
|
||||
<Checkbox
|
||||
checked={row.getIsSelected()}
|
||||
onCheckedChange={(value) => row.toggleSelected(!!value)}
|
||||
aria-label="Select row"
|
||||
/>
|
||||
),
|
||||
},
|
||||
{
|
||||
header: ({ column }) => {
|
||||
return (
|
||||
<Button
|
||||
variant="ghost"
|
||||
onClick={() => column.toggleSorting(column.getIsSorted() === "asc")}
|
||||
>
|
||||
Name
|
||||
<ArrowUpDown className="ml-2 h-4 w-4" />
|
||||
</Button>
|
||||
);
|
||||
},
|
||||
accessorKey: "name",
|
||||
},
|
||||
{
|
||||
header: "Description",
|
||||
accessorKey: "description",
|
||||
},
|
||||
{
|
||||
header: "Categories",
|
||||
accessorKey: "categories",
|
||||
},
|
||||
{
|
||||
header: "Keywords",
|
||||
accessorKey: "keywords",
|
||||
},
|
||||
{
|
||||
header: "Downloads",
|
||||
accessorKey: "downloads",
|
||||
},
|
||||
{
|
||||
header: "Author",
|
||||
accessorKey: "author",
|
||||
},
|
||||
{
|
||||
header: "Version",
|
||||
accessorKey: "version",
|
||||
},
|
||||
{
|
||||
header: "actions",
|
||||
cell: ({ row }) => {
|
||||
const handleRemove = async () => {
|
||||
await removeFeaturedAgentWithId();
|
||||
};
|
||||
// const handleEdit = async () => {
|
||||
// console.log("edit");
|
||||
// };
|
||||
const removeFeaturedAgentWithId = removeFeaturedAgent.bind(
|
||||
null,
|
||||
row.original.id,
|
||||
);
|
||||
return (
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button variant="outline" size="sm" onClick={handleRemove}>
|
||||
Remove
|
||||
</Button>
|
||||
{/* <Button variant="outline" size="sm" onClick={handleEdit}>
|
||||
Edit
|
||||
</Button> */}
|
||||
</div>
|
||||
);
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
export default function FeaturedAgentsTable({
|
||||
agents,
|
||||
globalActions,
|
||||
}: {
|
||||
agents: Agent[];
|
||||
globalActions: GlobalActions<Agent>[];
|
||||
}) {
|
||||
return (
|
||||
<DataTable
|
||||
columns={columns}
|
||||
data={agents}
|
||||
filterPlaceholder="Search agents..."
|
||||
filterColumn="name"
|
||||
globalActions={globalActions}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
"use server";
|
||||
import AutoGPTServerAPI from "@/lib/autogpt-server-api";
|
||||
import MarketplaceAPI from "@/lib/marketplace-api";
|
||||
import { revalidatePath } from "next/cache";
|
||||
|
||||
export async function approveAgent(
|
||||
agentId: string,
|
||||
@@ -9,6 +11,7 @@ export async function approveAgent(
|
||||
const api = new MarketplaceAPI();
|
||||
await api.approveAgentSubmission(agentId, version, comment);
|
||||
console.debug(`Approving agent ${agentId}`);
|
||||
revalidatePath("/marketplace");
|
||||
}
|
||||
|
||||
export async function rejectAgent(
|
||||
@@ -19,9 +22,64 @@ export async function rejectAgent(
|
||||
const api = new MarketplaceAPI();
|
||||
await api.rejectAgentSubmission(agentId, version, comment);
|
||||
console.debug(`Rejecting agent ${agentId}`);
|
||||
revalidatePath("/marketplace");
|
||||
}
|
||||
|
||||
export async function getReviewableAgents() {
|
||||
const api = new MarketplaceAPI();
|
||||
return api.getAgentSubmissions();
|
||||
}
|
||||
|
||||
export async function getFeaturedAgents(
|
||||
page: number = 1,
|
||||
pageSize: number = 10,
|
||||
) {
|
||||
const api = new MarketplaceAPI();
|
||||
const featured = await api.getFeaturedAgents(page, pageSize);
|
||||
console.debug(`Getting featured agents ${featured.agents.length}`);
|
||||
return featured;
|
||||
}
|
||||
|
||||
export async function getFeaturedAgent(agentId: string) {
|
||||
const api = new MarketplaceAPI();
|
||||
const featured = await api.getFeaturedAgent(agentId);
|
||||
console.debug(`Getting featured agent ${featured.agentId}`);
|
||||
return featured;
|
||||
}
|
||||
|
||||
export async function addFeaturedAgent(
|
||||
agentId: string,
|
||||
categories: string[] = ["featured"],
|
||||
) {
|
||||
const api = new MarketplaceAPI();
|
||||
await api.addFeaturedAgent(agentId, categories);
|
||||
console.debug(`Adding featured agent ${agentId}`);
|
||||
revalidatePath("/marketplace");
|
||||
}
|
||||
|
||||
export async function removeFeaturedAgent(
|
||||
agentId: string,
|
||||
categories: string[] = ["featured"],
|
||||
) {
|
||||
const api = new MarketplaceAPI();
|
||||
await api.removeFeaturedAgent(agentId, categories);
|
||||
console.debug(`Removing featured agent ${agentId}`);
|
||||
revalidatePath("/marketplace");
|
||||
}
|
||||
|
||||
export async function getCategories() {
|
||||
const api = new MarketplaceAPI();
|
||||
const categories = await api.getCategories();
|
||||
console.debug(`Getting categories ${categories.unique_categories.length}`);
|
||||
return categories;
|
||||
}
|
||||
|
||||
export async function getNotFeaturedAgents(
|
||||
page: number = 1,
|
||||
pageSize: number = 100,
|
||||
) {
|
||||
const api = new MarketplaceAPI();
|
||||
const agents = await api.getNotFeaturedAgents(page, pageSize);
|
||||
console.debug(`Getting not featured agents ${agents.agents.length}`);
|
||||
return agents;
|
||||
}
|
||||
|
||||
@@ -21,8 +21,13 @@ import AutoGPTServerAPI, {
|
||||
import { cn } from "@/lib/utils";
|
||||
import { EnterIcon } from "@radix-ui/react-icons";
|
||||
|
||||
// Add this custom schema for File type
|
||||
const fileSchema = z.custom<File>((val) => val instanceof File, {
|
||||
message: "Must be a File object",
|
||||
});
|
||||
|
||||
const formSchema = z.object({
|
||||
agentFile: z.instanceof(File),
|
||||
agentFile: fileSchema,
|
||||
agentName: z.string().min(1, "Agent name is required"),
|
||||
agentDescription: z.string(),
|
||||
importAsTemplate: z.boolean(),
|
||||
|
||||
@@ -17,6 +17,11 @@ import { IconToyBrick } from "@/components/ui/icons";
|
||||
import SchemaTooltip from "@/components/SchemaTooltip";
|
||||
import { getPrimaryCategoryColor } from "@/lib/utils";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
interface BlocksControlProps {
|
||||
blocks: Block[];
|
||||
@@ -60,17 +65,20 @@ export const BlocksControl: React.FC<BlocksControlProps> = ({
|
||||
|
||||
return (
|
||||
<Popover open={pinBlocksPopover ? true : undefined}>
|
||||
{" "}
|
||||
{/* Control popover open state */}
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-id="blocks-control-popover-trigger"
|
||||
>
|
||||
<IconToyBrick />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<Tooltip delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
data-id="blocks-control-popover-trigger"
|
||||
>
|
||||
<IconToyBrick />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">Blocks</TooltipContent>
|
||||
</Tooltip>
|
||||
<PopoverContent
|
||||
side="right"
|
||||
sideOffset={22}
|
||||
|
||||
@@ -10,6 +10,11 @@ import { Button } from "@/components/ui/button";
|
||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { IconSave } from "@/components/ui/icons";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
|
||||
interface SaveControlProps {
|
||||
agentMeta: GraphMeta | null;
|
||||
@@ -51,11 +56,16 @@ export const SaveControl = ({
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon">
|
||||
<IconSave />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
<Tooltip delayDuration={500}>
|
||||
<TooltipTrigger asChild>
|
||||
<PopoverTrigger asChild>
|
||||
<Button variant="ghost" size="icon">
|
||||
<IconSave />
|
||||
</Button>
|
||||
</PopoverTrigger>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="right">Save</TooltipContent>
|
||||
</Tooltip>
|
||||
<PopoverContent side="right" sideOffset={15} align="start">
|
||||
<Card className="border-none shadow-none">
|
||||
<CardContent className="p-4">
|
||||
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
} from "@/components/ui/table";
|
||||
import moment from "moment/moment";
|
||||
import { FlowRun } from "@/lib/types";
|
||||
import { DialogTitle } from "@/components/ui/dialog";
|
||||
|
||||
export const AgentFlowList = ({
|
||||
flows,
|
||||
@@ -102,8 +103,11 @@ export const AgentFlowList = ({
|
||||
</DropdownMenu>
|
||||
|
||||
<DialogContent>
|
||||
<DialogHeader className="text-lg">
|
||||
Import an Agent (template) from a file
|
||||
<DialogHeader>
|
||||
<DialogTitle className="sr-only">Import Agent</DialogTitle>
|
||||
<h2 className="text-lg font-semibold">
|
||||
Import an Agent (template) from a file
|
||||
</h2>
|
||||
</DialogHeader>
|
||||
<AgentImportForm />
|
||||
</DialogContent>
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import React from "react";
|
||||
import { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import React, { useCallback } from "react";
|
||||
import AutoGPTServerAPI, { GraphMeta } from "@/lib/autogpt-server-api";
|
||||
import { FlowRun } from "@/lib/types";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
|
||||
import Link from "next/link";
|
||||
import { buttonVariants } from "@/components/ui/button";
|
||||
import { Button, buttonVariants } from "@/components/ui/button";
|
||||
import { IconSquare } from "@/components/ui/icons";
|
||||
import { Pencil2Icon } from "@radix-ui/react-icons";
|
||||
import moment from "moment/moment";
|
||||
import { FlowRunStatusBadge } from "@/components/monitor/FlowRunStatusBadge";
|
||||
@@ -20,6 +21,11 @@ export const FlowRunInfo: React.FC<
|
||||
);
|
||||
}
|
||||
|
||||
const handleStopRun = useCallback(() => {
|
||||
const api = new AutoGPTServerAPI();
|
||||
api.stopGraphExecution(flow.id, flowRun.id);
|
||||
}, [flow.id, flowRun.id]);
|
||||
|
||||
return (
|
||||
<Card {...props}>
|
||||
<CardHeader className="flex-row items-center justify-between space-x-3 space-y-0">
|
||||
@@ -34,12 +40,19 @@ export const FlowRunInfo: React.FC<
|
||||
Run ID: <code>{flowRun.id}</code>
|
||||
</p>
|
||||
</div>
|
||||
<Link
|
||||
className={buttonVariants({ variant: "outline" })}
|
||||
href={`/build?flowID=${flow.id}`}
|
||||
>
|
||||
<Pencil2Icon className="mr-2" /> Edit Agent
|
||||
</Link>
|
||||
<div className="flex space-x-2">
|
||||
{flowRun.status === "running" && (
|
||||
<Button onClick={handleStopRun} variant="destructive">
|
||||
<IconSquare className="mr-2" /> Stop Run
|
||||
</Button>
|
||||
)}
|
||||
<Link
|
||||
className={buttonVariants({ variant: "outline" })}
|
||||
href={`/build?flowID=${flow.id}`}
|
||||
>
|
||||
<Pencil2Icon className="mr-2" /> Edit Agent
|
||||
</Link>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p>
|
||||
|
||||
209
rnd/autogpt_builder/src/components/ui/data-table.tsx
Normal file
209
rnd/autogpt_builder/src/components/ui/data-table.tsx
Normal file
@@ -0,0 +1,209 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
ColumnDef,
|
||||
ColumnFiltersState,
|
||||
SortingState,
|
||||
VisibilityState,
|
||||
flexRender,
|
||||
getCoreRowModel,
|
||||
useReactTable,
|
||||
getPaginationRowModel,
|
||||
getSortedRowModel,
|
||||
getFilteredRowModel,
|
||||
} from "@tanstack/react-table";
|
||||
import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
TableHead,
|
||||
TableBody,
|
||||
TableCell,
|
||||
Table,
|
||||
} from "@/components/ui/table";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuCheckboxItem,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { cloneElement, Fragment, useState } from "react";
|
||||
|
||||
export interface GlobalActions<TData> {
|
||||
component: React.ReactElement;
|
||||
action: (rows: TData[]) => Promise<void>;
|
||||
}
|
||||
|
||||
interface DataTableProps<TData, TValue> {
|
||||
columns: ColumnDef<TData, TValue>[];
|
||||
data: TData[];
|
||||
filterPlaceholder: string;
|
||||
filterColumn?: string;
|
||||
globalActions?: GlobalActions<TData>[];
|
||||
}
|
||||
|
||||
export function DataTable<TData, TValue>({
|
||||
columns,
|
||||
data,
|
||||
filterPlaceholder = "Filter...",
|
||||
filterColumn,
|
||||
globalActions = [],
|
||||
}: DataTableProps<TData, TValue>) {
|
||||
const [sorting, setSorting] = useState<SortingState>([]);
|
||||
const [columnFilters, setColumnFilters] = useState<ColumnFiltersState>([]);
|
||||
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>({});
|
||||
const [rowSelection, setRowSelection] = useState({});
|
||||
|
||||
const table = useReactTable({
|
||||
data,
|
||||
columns,
|
||||
getCoreRowModel: getCoreRowModel(),
|
||||
getPaginationRowModel: getPaginationRowModel(),
|
||||
getSortedRowModel: getSortedRowModel(),
|
||||
getFilteredRowModel: getFilteredRowModel(),
|
||||
onSortingChange: setSorting,
|
||||
onColumnFiltersChange: setColumnFilters,
|
||||
onColumnVisibilityChange: setColumnVisibility,
|
||||
onRowSelectionChange: setRowSelection,
|
||||
state: {
|
||||
sorting,
|
||||
columnFilters,
|
||||
columnVisibility,
|
||||
rowSelection,
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex items-center gap-2 py-4">
|
||||
{filterColumn && (
|
||||
<Input
|
||||
placeholder={filterPlaceholder}
|
||||
value={
|
||||
(table.getColumn(filterColumn)?.getFilterValue() as string) ?? ""
|
||||
}
|
||||
onChange={(event) =>
|
||||
table.getColumn(filterColumn)?.setFilterValue(event.target.value)
|
||||
}
|
||||
className="max-w-sm"
|
||||
/>
|
||||
)}
|
||||
|
||||
{globalActions &&
|
||||
globalActions.map((action, index) => {
|
||||
return (
|
||||
<Fragment key={index}>
|
||||
<div className="flex items-center">
|
||||
{cloneElement(action.component, {
|
||||
onClick: () => {
|
||||
const filteredSelectedRows = table
|
||||
.getFilteredSelectedRowModel()
|
||||
.rows.map((row) => row.original);
|
||||
action.action(filteredSelectedRows);
|
||||
},
|
||||
})}
|
||||
</div>
|
||||
</Fragment>
|
||||
);
|
||||
})}
|
||||
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button variant="outline" className="ml-auto">
|
||||
Columns
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end">
|
||||
{table
|
||||
.getAllColumns()
|
||||
.filter((column) => column.getCanHide())
|
||||
.map((column) => {
|
||||
return (
|
||||
<DropdownMenuCheckboxItem
|
||||
key={column.id}
|
||||
className="capitalize"
|
||||
checked={column.getIsVisible()}
|
||||
onCheckedChange={(value) =>
|
||||
column.toggleVisibility(!!value)
|
||||
}
|
||||
>
|
||||
{column.id}
|
||||
</DropdownMenuCheckboxItem>
|
||||
);
|
||||
})}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
|
||||
<div className="rounded-md border">
|
||||
<Table>
|
||||
<TableHeader>
|
||||
{table.getHeaderGroups().map((headerGroup) => (
|
||||
<TableRow key={headerGroup.id}>
|
||||
{headerGroup.headers.map((header) => {
|
||||
return (
|
||||
<TableHead key={header.id}>
|
||||
{header.isPlaceholder
|
||||
? null
|
||||
: flexRender(
|
||||
header.column.columnDef.header,
|
||||
header.getContext(),
|
||||
)}
|
||||
</TableHead>
|
||||
);
|
||||
})}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{table.getRowModel().rows?.length ? (
|
||||
table.getRowModel().rows.map((row) => (
|
||||
<TableRow
|
||||
key={row.id}
|
||||
data-state={row.getIsSelected() && "selected"}
|
||||
>
|
||||
{row.getVisibleCells().map((cell) => (
|
||||
<TableCell key={cell.id}>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext(),
|
||||
)}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={columns.length}
|
||||
className="h-24 text-center"
|
||||
>
|
||||
No results.
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
<div className="flex items-center justify-end space-x-2 py-4">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.previousPage()}
|
||||
disabled={!table.getCanPreviousPage()}
|
||||
>
|
||||
Previous
|
||||
</Button>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => table.nextPage()}
|
||||
disabled={!table.getCanNextPage()}
|
||||
>
|
||||
Next
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -405,6 +405,40 @@ export const IconPlay = createIcon((props) => (
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Square icon component.
|
||||
*
|
||||
* @component IconSquare
|
||||
* @param {IconProps} props - The props object containing additional attributes and event handlers for the icon.
|
||||
* @returns {JSX.Element} - The square icon.
|
||||
*
|
||||
* @example
|
||||
* // Default usage this is the standard usage
|
||||
* <IconSquare />
|
||||
*
|
||||
* @example
|
||||
* // With custom color and size these should be used sparingly and only when necessary
|
||||
* <IconSquare className="text-primary" size="lg" />
|
||||
*
|
||||
* @example
|
||||
* // With custom size and onClick handler
|
||||
* <IconSquare size="sm" onClick={handleOnClick} />
|
||||
*/
|
||||
export const IconSquare = createIcon((props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
{...props}
|
||||
>
|
||||
<rect width="18" height="18" x="3" y="3" rx="2" />
|
||||
</svg>
|
||||
));
|
||||
|
||||
/**
|
||||
* Package2 icon component.
|
||||
*
|
||||
|
||||
@@ -11,10 +11,15 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
// The HTMLvalue will only be updated if the value prop changes, but the user can still type in the input.
|
||||
ref = ref || React.createRef<HTMLInputElement>();
|
||||
React.useEffect(() => {
|
||||
if (ref && ref.current && ref.current.value !== value) {
|
||||
if (
|
||||
ref &&
|
||||
ref.current &&
|
||||
ref.current.value !== value &&
|
||||
type !== "file"
|
||||
) {
|
||||
ref.current.value = value;
|
||||
}
|
||||
}, [value]);
|
||||
}, [value, type]);
|
||||
return (
|
||||
<input
|
||||
type={type}
|
||||
@@ -24,7 +29,7 @@ const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
className,
|
||||
)}
|
||||
ref={ref}
|
||||
defaultValue={value}
|
||||
defaultValue={type !== "file" ? value : undefined}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
|
||||
@@ -31,22 +31,27 @@ export default function useAgentGraph(
|
||||
const [updateQueue, setUpdateQueue] = useState<NodeExecutionResult[]>([]);
|
||||
const processedUpdates = useRef<NodeExecutionResult[]>([]);
|
||||
/**
|
||||
* User `request` to save or save&run the agent
|
||||
* User `request` to save or save&run the agent, or to stop the active run.
|
||||
* `state` is used to track the request status:
|
||||
* - none: no request
|
||||
* - saving: request was sent to save the agent
|
||||
* and nodes are pending sync to update their backend ids
|
||||
* - running: request was sent to run the agent
|
||||
* and frontend is enqueueing execution results
|
||||
* - stopping: a request to stop the active run has been sent; response is pending
|
||||
* - error: request failed
|
||||
*
|
||||
* As of now, state will be stuck at 'running' (if run requested)
|
||||
* because there's no way to know when the execution is done
|
||||
*/
|
||||
const [saveRunRequest, setSaveRunRequest] = useState<{
|
||||
request: "none" | "save" | "run";
|
||||
state: "none" | "saving" | "running" | "error";
|
||||
}>({
|
||||
const [saveRunRequest, setSaveRunRequest] = useState<
|
||||
| {
|
||||
request: "none" | "save" | "run";
|
||||
state: "none" | "saving" | "error";
|
||||
}
|
||||
| {
|
||||
request: "run" | "stop";
|
||||
state: "running" | "stopping" | "error";
|
||||
activeExecutionID?: string;
|
||||
}
|
||||
>({
|
||||
request: "none",
|
||||
state: "none",
|
||||
});
|
||||
@@ -128,13 +133,14 @@ export default function useAgentGraph(
|
||||
console.error("Error saving agent");
|
||||
} else if (saveRunRequest.request === "run") {
|
||||
console.error(`Error saving&running agent`);
|
||||
} else if (saveRunRequest.request === "stop") {
|
||||
console.error(`Error stopping agent`);
|
||||
}
|
||||
// Reset request
|
||||
setSaveRunRequest((prev) => ({
|
||||
...prev,
|
||||
setSaveRunRequest({
|
||||
request: "none",
|
||||
state: "none",
|
||||
}));
|
||||
});
|
||||
return;
|
||||
}
|
||||
// When saving request is done
|
||||
@@ -145,11 +151,10 @@ export default function useAgentGraph(
|
||||
) {
|
||||
// Reset request if only save was requested
|
||||
if (saveRunRequest.request === "save") {
|
||||
setSaveRunRequest((prev) => ({
|
||||
...prev,
|
||||
setSaveRunRequest({
|
||||
request: "none",
|
||||
state: "none",
|
||||
}));
|
||||
});
|
||||
// If run was requested, run the agent
|
||||
} else if (saveRunRequest.request === "run") {
|
||||
if (!validateNodes()) {
|
||||
@@ -161,16 +166,64 @@ export default function useAgentGraph(
|
||||
return;
|
||||
}
|
||||
api.subscribeToExecution(savedAgent.id);
|
||||
api.executeGraph(savedAgent.id);
|
||||
processedUpdates.current = processedUpdates.current = [];
|
||||
setSaveRunRequest({ request: "run", state: "running" });
|
||||
api
|
||||
.executeGraph(savedAgent.id)
|
||||
.then((graphExecution) => {
|
||||
setSaveRunRequest({
|
||||
request: "run",
|
||||
state: "running",
|
||||
activeExecutionID: graphExecution.id,
|
||||
});
|
||||
|
||||
setSaveRunRequest((prev) => ({
|
||||
...prev,
|
||||
request: "run",
|
||||
state: "running",
|
||||
}));
|
||||
// Track execution until completed
|
||||
const pendingNodeExecutions: Set<string> = new Set();
|
||||
const cancelExecListener = api.onWebSocketMessage(
|
||||
"execution_event",
|
||||
(nodeResult) => {
|
||||
// We are racing the server here, since we need the ID to filter events
|
||||
if (nodeResult.graph_exec_id != graphExecution.id) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
nodeResult.status != "COMPLETED" &&
|
||||
nodeResult.status != "FAILED"
|
||||
) {
|
||||
pendingNodeExecutions.add(nodeResult.node_exec_id);
|
||||
} else {
|
||||
pendingNodeExecutions.delete(nodeResult.node_exec_id);
|
||||
}
|
||||
if (pendingNodeExecutions.size == 0) {
|
||||
// Assuming the first event is always a QUEUED node, and
|
||||
// following nodes are QUEUED before all preceding nodes are COMPLETED,
|
||||
// an empty set means the graph has finished running.
|
||||
cancelExecListener();
|
||||
setSaveRunRequest({ request: "none", state: "none" });
|
||||
}
|
||||
},
|
||||
);
|
||||
})
|
||||
.catch(() => setSaveRunRequest({ request: "run", state: "error" }));
|
||||
|
||||
processedUpdates.current = processedUpdates.current = [];
|
||||
}
|
||||
}
|
||||
// Handle stop request
|
||||
if (
|
||||
saveRunRequest.request === "stop" &&
|
||||
saveRunRequest.state != "stopping" &&
|
||||
savedAgent &&
|
||||
saveRunRequest.activeExecutionID
|
||||
) {
|
||||
setSaveRunRequest({
|
||||
request: "stop",
|
||||
state: "stopping",
|
||||
activeExecutionID: saveRunRequest.activeExecutionID,
|
||||
});
|
||||
api
|
||||
.stopGraphExecution(savedAgent.id, saveRunRequest.activeExecutionID)
|
||||
.then(() => setSaveRunRequest({ request: "none", state: "none" }));
|
||||
}
|
||||
}, [saveRunRequest, savedAgent, nodesSyncedWithSavedAgent]);
|
||||
|
||||
// Check if node ids are synced with saved agent
|
||||
@@ -657,7 +710,7 @@ export default function useAgentGraph(
|
||||
[saveAgent],
|
||||
);
|
||||
|
||||
const requestSaveRun = useCallback(() => {
|
||||
const requestSaveAndRun = useCallback(() => {
|
||||
saveAgent();
|
||||
setSaveRunRequest({
|
||||
request: "run",
|
||||
@@ -665,6 +718,23 @@ export default function useAgentGraph(
|
||||
});
|
||||
}, [saveAgent]);
|
||||
|
||||
const requestStopRun = useCallback(() => {
|
||||
if (saveRunRequest.state != "running") {
|
||||
return;
|
||||
}
|
||||
if (!saveRunRequest.activeExecutionID) {
|
||||
console.warn(
|
||||
"Stop requested but execution ID is unknown; state:",
|
||||
saveRunRequest,
|
||||
);
|
||||
}
|
||||
setSaveRunRequest((prev) => ({
|
||||
...prev,
|
||||
request: "stop",
|
||||
state: "running",
|
||||
}));
|
||||
}, [saveRunRequest]);
|
||||
|
||||
return {
|
||||
agentName,
|
||||
setAgentName,
|
||||
@@ -674,7 +744,11 @@ export default function useAgentGraph(
|
||||
availableNodes,
|
||||
getOutputType,
|
||||
requestSave,
|
||||
requestSaveRun,
|
||||
requestSaveAndRun,
|
||||
requestStopRun,
|
||||
isSaving: saveRunRequest.state == "saving",
|
||||
isRunning: saveRunRequest.state == "running",
|
||||
isStopping: saveRunRequest.state == "stopping",
|
||||
nodes,
|
||||
setNodes,
|
||||
edges,
|
||||
|
||||
@@ -15,7 +15,7 @@ export default class AutoGPTServerAPI {
|
||||
private wsUrl: string;
|
||||
private webSocket: WebSocket | null = null;
|
||||
private wsConnecting: Promise<void> | null = null;
|
||||
private wsMessageHandlers: { [key: string]: (data: any) => void } = {};
|
||||
private wsMessageHandlers: Record<string, Set<(data: any) => void>> = {};
|
||||
private supabaseClient = createClient();
|
||||
|
||||
constructor(
|
||||
@@ -128,16 +128,19 @@ export default class AutoGPTServerAPI {
|
||||
runID: string,
|
||||
): Promise<NodeExecutionResult[]> {
|
||||
return (await this._get(`/graphs/${graphID}/executions/${runID}`)).map(
|
||||
(result: any) => ({
|
||||
...result,
|
||||
add_time: new Date(result.add_time),
|
||||
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
|
||||
start_time: result.start_time ? new Date(result.start_time) : undefined,
|
||||
end_time: result.end_time ? new Date(result.end_time) : undefined,
|
||||
}),
|
||||
parseNodeExecutionResultTimestamps,
|
||||
);
|
||||
}
|
||||
|
||||
async stopGraphExecution(
|
||||
graphID: string,
|
||||
runID: string,
|
||||
): Promise<NodeExecutionResult[]> {
|
||||
return (
|
||||
await this._request("POST", `/graphs/${graphID}/executions/${runID}/stop`)
|
||||
).map(parseNodeExecutionResultTimestamps);
|
||||
}
|
||||
|
||||
private async _get(path: string) {
|
||||
return this._request("GET", path);
|
||||
}
|
||||
@@ -207,10 +210,13 @@ export default class AutoGPTServerAPI {
|
||||
};
|
||||
|
||||
this.webSocket.onmessage = (event) => {
|
||||
const message = JSON.parse(event.data);
|
||||
if (this.wsMessageHandlers[message.method]) {
|
||||
this.wsMessageHandlers[message.method](message.data);
|
||||
const message: WebsocketMessage = JSON.parse(event.data);
|
||||
if (message.method == "execution_event") {
|
||||
message.data = parseNodeExecutionResultTimestamps(message.data);
|
||||
}
|
||||
this.wsMessageHandlers[message.method]?.forEach((handler) =>
|
||||
handler(message.data),
|
||||
);
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("Error connecting to WebSocket:", error);
|
||||
@@ -250,8 +256,12 @@ export default class AutoGPTServerAPI {
|
||||
onWebSocketMessage<M extends keyof WebsocketMessageTypeMap>(
|
||||
method: M,
|
||||
handler: (data: WebsocketMessageTypeMap[M]) => void,
|
||||
) {
|
||||
this.wsMessageHandlers[method] = handler;
|
||||
): () => void {
|
||||
this.wsMessageHandlers[method] ??= new Set();
|
||||
this.wsMessageHandlers[method].add(handler);
|
||||
|
||||
// Return detacher
|
||||
return () => this.wsMessageHandlers[method].delete(handler);
|
||||
}
|
||||
|
||||
subscribeToExecution(graphId: string) {
|
||||
@@ -274,3 +284,22 @@ type WebsocketMessageTypeMap = {
|
||||
subscribe: { graph_id: string };
|
||||
execution_event: NodeExecutionResult;
|
||||
};
|
||||
|
||||
type WebsocketMessage = {
|
||||
[M in keyof WebsocketMessageTypeMap]: {
|
||||
method: M;
|
||||
data: WebsocketMessageTypeMap[M];
|
||||
};
|
||||
}[keyof WebsocketMessageTypeMap];
|
||||
|
||||
/* *** HELPER FUNCTIONS *** */
|
||||
|
||||
function parseNodeExecutionResultTimestamps(result: any): NodeExecutionResult {
|
||||
return {
|
||||
...result,
|
||||
add_time: new Date(result.add_time),
|
||||
queue_time: result.queue_time ? new Date(result.queue_time) : undefined,
|
||||
start_time: result.start_time ? new Date(result.start_time) : undefined,
|
||||
end_time: result.end_time ? new Date(result.end_time) : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import {
|
||||
AgentListResponse,
|
||||
AgentDetailResponse,
|
||||
AgentWithRank,
|
||||
FeaturedAgentResponse,
|
||||
UniqueCategoriesResponse,
|
||||
} from "./types";
|
||||
|
||||
export default class MarketplaceAPI {
|
||||
@@ -155,6 +157,42 @@ export default class MarketplaceAPI {
|
||||
});
|
||||
}
|
||||
|
||||
async addFeaturedAgent(
|
||||
agentId: string,
|
||||
categories: string[],
|
||||
): Promise<FeaturedAgentResponse> {
|
||||
const response = await this._post(`/admin/agent/featured/${agentId}`, {
|
||||
categories: categories,
|
||||
});
|
||||
return response;
|
||||
}
|
||||
|
||||
async removeFeaturedAgent(
|
||||
agentId: string,
|
||||
categories: string[],
|
||||
): Promise<FeaturedAgentResponse> {
|
||||
return this._delete(`/admin/agent/featured/${agentId}`, {
|
||||
categories: categories,
|
||||
});
|
||||
}
|
||||
|
||||
async getFeaturedAgent(agentId: string): Promise<FeaturedAgentResponse> {
|
||||
return this._get(`/admin/agent/featured/${agentId}`);
|
||||
}
|
||||
|
||||
async getNotFeaturedAgents(
|
||||
page: number = 1,
|
||||
pageSize: number = 10,
|
||||
): Promise<AgentListResponse> {
|
||||
return this._get(
|
||||
`/admin/agent/not-featured?page=${page}&page_size=${pageSize}`,
|
||||
);
|
||||
}
|
||||
|
||||
async getCategories(): Promise<UniqueCategoriesResponse> {
|
||||
return this._get("/admin/categories");
|
||||
}
|
||||
|
||||
private async _get(path: string) {
|
||||
return this._request("GET", path);
|
||||
}
|
||||
@@ -163,6 +201,10 @@ export default class MarketplaceAPI {
|
||||
return this._request("POST", path, payload);
|
||||
}
|
||||
|
||||
private async _delete(path: string, payload: { [key: string]: any }) {
|
||||
return this._request("DELETE", path, payload);
|
||||
}
|
||||
|
||||
private async _getBlob(path: string): Promise<Blob> {
|
||||
const response = await fetch(this.baseUrl + path);
|
||||
if (!response.ok) {
|
||||
@@ -178,7 +220,7 @@ export default class MarketplaceAPI {
|
||||
}
|
||||
|
||||
private async _request(
|
||||
method: "GET" | "POST" | "PUT" | "PATCH",
|
||||
method: "GET" | "POST" | "PUT" | "PATCH" | "DELETE",
|
||||
path: string,
|
||||
payload?: { [key: string]: any },
|
||||
) {
|
||||
|
||||
@@ -43,6 +43,22 @@ export type AgentList = {
|
||||
total_pages: number;
|
||||
};
|
||||
|
||||
export type FeaturedAgentResponse = {
|
||||
agentId: string;
|
||||
featuredCategories: string[];
|
||||
createdAt: string; // ISO8601 datetime string
|
||||
updatedAt: string; // ISO8601 datetime string
|
||||
isActive: boolean;
|
||||
};
|
||||
|
||||
export type FeaturedAgentsList = {
|
||||
agents: FeaturedAgentResponse[];
|
||||
total_count: number;
|
||||
page: number;
|
||||
page_size: number;
|
||||
total_pages: number;
|
||||
};
|
||||
|
||||
export type AgentDetail = Agent & {
|
||||
graph: Record<string, any>;
|
||||
};
|
||||
@@ -56,3 +72,7 @@ export type AgentListResponse = AgentList;
|
||||
export type AgentDetailResponse = AgentDetail;
|
||||
|
||||
export type AgentResponse = Agent;
|
||||
|
||||
export type UniqueCategoriesResponse = {
|
||||
unique_categories: string[];
|
||||
};
|
||||
|
||||
@@ -838,6 +838,18 @@
|
||||
"@swc/counter" "^0.1.3"
|
||||
tslib "^2.4.0"
|
||||
|
||||
"@tanstack/react-table@^8.20.5":
|
||||
version "8.20.5"
|
||||
resolved "https://registry.yarnpkg.com/@tanstack/react-table/-/react-table-8.20.5.tgz#19987d101e1ea25ef5406dce4352cab3932449d8"
|
||||
integrity sha512-WEHopKw3znbUZ61s9i0+i9g8drmDo6asTWbrQh8Us63DAk/M0FkmIqERew6P71HI75ksZ2Pxyuf4vvKh9rAkiA==
|
||||
dependencies:
|
||||
"@tanstack/table-core" "8.20.5"
|
||||
|
||||
"@tanstack/table-core@8.20.5":
|
||||
version "8.20.5"
|
||||
resolved "https://registry.yarnpkg.com/@tanstack/table-core/-/table-core-8.20.5.tgz#3974f0b090bed11243d4107283824167a395cf1d"
|
||||
integrity sha512-P9dF7XbibHph2PFRz8gfBKEXEY/HJPOhym8CHmjF8y3q5mWpKx9xtZapXQUWCgkqvsK0R46Azuz+VaxD4Xl+Tg==
|
||||
|
||||
"@types/d3-array@^3.0.3":
|
||||
version "3.2.1"
|
||||
resolved "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.1.tgz"
|
||||
|
||||
@@ -4,6 +4,50 @@ DB_NAME=agpt_local
|
||||
DB_PORT=5432
|
||||
DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@localhost:${DB_PORT}/${DB_NAME}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
ENABLE_AUTH="false"
|
||||
|
||||
REDIS_HOST=localhost
|
||||
REDIS_PORT=6379
|
||||
REDIS_PASSWORD=password
|
||||
|
||||
AUTH_ENABLED=false
|
||||
APP_ENV="local"
|
||||
PYRO_HOST=localhost
|
||||
PYRO_HOST=localhost
|
||||
SENTRY_DSN=
|
||||
|
||||
## ===== OPTIONAL API KEYS ===== ##
|
||||
|
||||
# LLM
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
|
||||
# Reddit
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USERNAME=
|
||||
REDDIT_PASSWORD=
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
|
||||
# SMTP/Email
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# D-ID
|
||||
DID_API_KEY=
|
||||
|
||||
# Open Weather Map
|
||||
OPENWEATHERMAP_API_KEY=
|
||||
|
||||
# SMTP
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# Medium
|
||||
MEDIUM_API_KEY=
|
||||
MEDIUM_AUTHOR_ID=
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
# LLM
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
|
||||
# Reddit
|
||||
REDDIT_CLIENT_ID=
|
||||
REDDIT_CLIENT_SECRET=
|
||||
REDDIT_USERNAME=
|
||||
REDDIT_PASSWORD=
|
||||
|
||||
# Discord
|
||||
DISCORD_BOT_TOKEN=
|
||||
|
||||
# SMTP/Email
|
||||
SMTP_SERVER=
|
||||
SMTP_PORT=
|
||||
SMTP_USERNAME=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# D-ID
|
||||
DID_API_KEY=
|
||||
@@ -30,25 +30,22 @@ COPY rnd/autogpt_libs /app/rnd/autogpt_libs
|
||||
WORKDIR /app/rnd/autogpt_server
|
||||
|
||||
COPY rnd/autogpt_server/pyproject.toml rnd/autogpt_server/poetry.lock ./
|
||||
|
||||
RUN poetry install --no-interaction --no-ansi
|
||||
|
||||
COPY rnd/autogpt_server /app/rnd/autogpt_server
|
||||
|
||||
WORKDIR /app/rnd/autogpt_server
|
||||
|
||||
COPY rnd/autogpt_server/schema.prisma ./
|
||||
RUN poetry run prisma generate
|
||||
|
||||
COPY rnd/autogpt_server /app/rnd/autogpt_server
|
||||
FROM server_base as server
|
||||
|
||||
ENV PORT=8000
|
||||
ENV DATABASE_URL=""
|
||||
|
||||
CMD ["poetry", "run", "app"]
|
||||
CMD ["poetry", "run", "rest"]
|
||||
|
||||
FROM server_base as manager
|
||||
|
||||
ENV PORT=8002
|
||||
ENV DATABASE_URL=""
|
||||
|
||||
CMD ["poetry", "run", "manager"]
|
||||
CMD ["poetry", "run", "manager"]
|
||||
|
||||
@@ -32,15 +32,14 @@ COPY rnd/autogpt_libs /app/rnd/autogpt_libs
|
||||
WORKDIR /app/rnd/autogpt_server
|
||||
|
||||
COPY rnd/autogpt_server/pyproject.toml rnd/autogpt_server/poetry.lock ./
|
||||
|
||||
RUN poetry install --no-interaction --no-ansi
|
||||
|
||||
COPY rnd/autogpt_server /app/rnd/autogpt_server
|
||||
|
||||
WORKDIR /app/rnd/autogpt_server
|
||||
|
||||
COPY rnd/autogpt_server/schema.prisma ./
|
||||
RUN poetry run prisma generate
|
||||
|
||||
COPY rnd/autogpt_server /app/rnd/autogpt_server
|
||||
FROM server_base as server
|
||||
|
||||
FROM server_base as server
|
||||
|
||||
ENV PORT=8001
|
||||
|
||||
@@ -1,52 +1,40 @@
|
||||
from multiprocessing import freeze_support, set_start_method
|
||||
from typing import TYPE_CHECKING
|
||||
from .util.logging import configure_logging
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt_server.util.process import AppProcess
|
||||
|
||||
|
||||
def run_processes(processes: list["AppProcess"], **kwargs):
|
||||
def run_processes(*processes: "AppProcess", **kwargs):
|
||||
"""
|
||||
Execute all processes in the app. The last process is run in the foreground.
|
||||
"""
|
||||
try:
|
||||
processes[0].start(background=False, **kwargs)
|
||||
except Exception as e:
|
||||
for process in processes[:-1]:
|
||||
process.start(background=True, **kwargs)
|
||||
|
||||
# Run the last process in the foreground
|
||||
processes[-1].start(background=False, **kwargs)
|
||||
finally:
|
||||
for process in processes:
|
||||
process.stop()
|
||||
raise e
|
||||
|
||||
|
||||
def main(**kwargs):
|
||||
set_start_method("spawn", force=True)
|
||||
freeze_support()
|
||||
configure_logging()
|
||||
"""
|
||||
Run all the processes required for the AutoGPT-server (REST and WebSocket APIs).
|
||||
"""
|
||||
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.executor import ExecutionScheduler
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer, WebsocketServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
run_processes(
|
||||
[
|
||||
ExecutionScheduler(),
|
||||
AgentServer(),
|
||||
],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
def execution_manager(**kwargs):
|
||||
set_start_method("spawn", force=True)
|
||||
freeze_support()
|
||||
configure_logging()
|
||||
|
||||
from autogpt_server.executor import ExecutionManager
|
||||
|
||||
run_processes(
|
||||
[
|
||||
ExecutionManager(),
|
||||
],
|
||||
**kwargs
|
||||
PyroNameServer(),
|
||||
ExecutionManager(),
|
||||
ExecutionScheduler(),
|
||||
WebsocketServer(),
|
||||
AgentServer(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -54,6 +54,15 @@ for cls in all_subclasses(Block):
|
||||
if block.id in AVAILABLE_BLOCKS:
|
||||
raise ValueError(f"Block ID {block.name} error: {block.id} is already in use")
|
||||
|
||||
# Prevent duplicate field name in input_schema and output_schema
|
||||
duplicate_field_names = set(block.input_schema.__fields__.keys()) & set(
|
||||
block.output_schema.__fields__.keys()
|
||||
)
|
||||
if duplicate_field_names:
|
||||
raise ValueError(
|
||||
f"{block.name} has duplicate field names in input_schema and output_schema: {duplicate_field_names}"
|
||||
)
|
||||
|
||||
for field in block.input_schema.__fields__.values():
|
||||
if field.annotation is bool and field.default not in (True, False):
|
||||
raise ValueError(f"{block.name} has a boolean field with no default value")
|
||||
|
||||
@@ -140,7 +140,7 @@ class InputOutputBlockInput(BlockSchema, Generic[T]):
|
||||
|
||||
|
||||
class InputOutputBlockOutput(BlockSchema, Generic[T]):
|
||||
value: T = Field(description="The value passed as input/output.")
|
||||
result: T = Field(description="The value passed as input/output.")
|
||||
|
||||
|
||||
class InputOutputBlockBase(Block, ABC, Generic[T]):
|
||||
@@ -162,8 +162,8 @@ class InputOutputBlockBase(Block, ABC, Generic[T]):
|
||||
{"value": MockObject(value="!!", key="key"), "name": "input_2"},
|
||||
],
|
||||
test_output=[
|
||||
("value", {"apple": 1, "banana": 2, "cherry": 3}),
|
||||
("value", MockObject(value="!!", key="key")),
|
||||
("result", {"apple": 1, "banana": 2, "cherry": 3}),
|
||||
("result", MockObject(value="!!", key="key")),
|
||||
],
|
||||
static_output=True,
|
||||
*args,
|
||||
@@ -171,7 +171,7 @@ class InputOutputBlockBase(Block, ABC, Generic[T]):
|
||||
)
|
||||
|
||||
def run(self, input_data: InputOutputBlockInput[T]) -> BlockOutput:
|
||||
yield "value", input_data.value
|
||||
yield "result", input_data.value
|
||||
|
||||
|
||||
class InputBlock(InputOutputBlockBase[Any]):
|
||||
|
||||
@@ -84,6 +84,9 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
sys_prompt: str = ""
|
||||
retry: int = 3
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False, default={}, description="Values used to fill in the prompt."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: dict[str, str]
|
||||
@@ -167,6 +170,11 @@ class AIStructuredResponseGeneratorBlock(Block):
|
||||
lines = s.strip().split("\n")
|
||||
return "\n".join([line.strip().lstrip("|") for line in lines])
|
||||
|
||||
values = input_data.prompt_values
|
||||
if values:
|
||||
input_data.prompt = input_data.prompt.format(**values)
|
||||
input_data.sys_prompt = input_data.sys_prompt.format(**values)
|
||||
|
||||
if input_data.sys_prompt:
|
||||
prompt.append({"role": "system", "content": input_data.sys_prompt})
|
||||
|
||||
@@ -252,6 +260,9 @@ class AITextGeneratorBlock(Block):
|
||||
api_key: BlockSecret = SecretField(value="")
|
||||
sys_prompt: str = ""
|
||||
retry: int = 3
|
||||
prompt_values: dict[str, str] = SchemaField(
|
||||
advanced=False, default={}, description="Values used to fill in the prompt."
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
response: str
|
||||
|
||||
@@ -57,7 +57,6 @@ class PublishToMediumBlock(Block):
|
||||
class Output(BlockSchema):
|
||||
post_id: str = SchemaField(description="The ID of the created Medium post")
|
||||
post_url: str = SchemaField(description="The URL of the created Medium post")
|
||||
author_id: str = SchemaField(description="The Medium user ID of the author")
|
||||
published_at: int = SchemaField(
|
||||
description="The timestamp when the post was published"
|
||||
)
|
||||
@@ -85,7 +84,6 @@ class PublishToMediumBlock(Block):
|
||||
test_output=[
|
||||
("post_id", "e6f36a"),
|
||||
("post_url", "https://medium.com/@username/test-post-e6f36a"),
|
||||
("author_id", "1234567890abcdef"),
|
||||
("published_at", 1626282600),
|
||||
],
|
||||
test_mock={
|
||||
@@ -156,7 +154,6 @@ class PublishToMediumBlock(Block):
|
||||
if "data" in response:
|
||||
yield "post_id", response["data"]["id"]
|
||||
yield "post_url", response["data"]["url"]
|
||||
yield "author_id", response["data"]["authorId"]
|
||||
yield "published_at", response["data"]["publishedAt"]
|
||||
else:
|
||||
error_message = response.get("errors", [{}])[0].get(
|
||||
|
||||
@@ -103,14 +103,14 @@ class GetCurrentDateAndTimeBlock(Block):
|
||||
|
||||
class CountdownTimerBlock(Block):
|
||||
class Input(BlockSchema):
|
||||
message: Any = "timer finished"
|
||||
input_message: Any = "timer finished"
|
||||
seconds: Union[int, str] = 0
|
||||
minutes: Union[int, str] = 0
|
||||
hours: Union[int, str] = 0
|
||||
days: Union[int, str] = 0
|
||||
|
||||
class Output(BlockSchema):
|
||||
message: str
|
||||
output_message: str
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -121,11 +121,11 @@ class CountdownTimerBlock(Block):
|
||||
output_schema=CountdownTimerBlock.Output,
|
||||
test_input=[
|
||||
{"seconds": 1},
|
||||
{"message": "Custom message"},
|
||||
{"input_message": "Custom message"},
|
||||
],
|
||||
test_output=[
|
||||
("message", "timer finished"),
|
||||
("message", "Custom message"),
|
||||
("output_message", "timer finished"),
|
||||
("output_message", "Custom message"),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -139,4 +139,4 @@ class CountdownTimerBlock(Block):
|
||||
total_seconds = seconds + minutes * 60 + hours * 3600 + days * 86400
|
||||
|
||||
time.sleep(total_seconds)
|
||||
yield "message", input_data.message
|
||||
yield "output_message", input_data.input_message
|
||||
|
||||
@@ -59,6 +59,12 @@ class BlockSchema(BaseModel):
|
||||
return obj
|
||||
|
||||
cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))
|
||||
|
||||
# Set default properties values
|
||||
for field in cls.cached_jsonschema.get("properties", {}).values():
|
||||
if isinstance(field, dict) and "advanced" not in field:
|
||||
field["advanced"] = True
|
||||
|
||||
return cls.cached_jsonschema
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from uuid import uuid4
|
||||
@@ -11,17 +13,35 @@ load_dotenv()
|
||||
PRISMA_SCHEMA = os.getenv("PRISMA_SCHEMA", "schema.prisma")
|
||||
os.environ["PRISMA_SCHEMA_PATH"] = PRISMA_SCHEMA
|
||||
|
||||
prisma = Prisma(auto_register=True)
|
||||
prisma, conn_id = Prisma(auto_register=True), ""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def connect():
|
||||
if not prisma.is_connected():
|
||||
await prisma.connect()
|
||||
async def connect(call_count=0):
|
||||
global conn_id
|
||||
if not conn_id:
|
||||
conn_id = str(uuid4())
|
||||
|
||||
try:
|
||||
logger.info(f"[Prisma-{conn_id}] Acquiring connection..")
|
||||
if not prisma.is_connected():
|
||||
await prisma.connect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection acquired!")
|
||||
except Exception as e:
|
||||
if call_count <= 5:
|
||||
logger.info(f"[Prisma-{conn_id}] Connection failed: {e}. Retrying now..")
|
||||
await asyncio.sleep(call_count)
|
||||
await connect(call_count + 1)
|
||||
else:
|
||||
raise e
|
||||
|
||||
|
||||
async def disconnect():
|
||||
if prisma.is_connected():
|
||||
logger.info(f"[Prisma-{conn_id}] Releasing connection.")
|
||||
await prisma.disconnect()
|
||||
logger.info(f"[Prisma-{conn_id}] Connection released.")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -274,7 +274,10 @@ async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
|
||||
|
||||
|
||||
async def update_execution_status(
|
||||
node_exec_id: str, status: ExecutionStatus, execution_data: BlockInput | None = None
|
||||
node_exec_id: str,
|
||||
status: ExecutionStatus,
|
||||
execution_data: BlockInput | None = None,
|
||||
stats: dict[str, Any] | None = None,
|
||||
) -> ExecutionResult:
|
||||
if status == ExecutionStatus.QUEUED and execution_data is None:
|
||||
raise ValueError("Execution data must be provided when queuing an execution.")
|
||||
@@ -287,6 +290,7 @@ async def update_execution_status(
|
||||
**({"endedTime": now} if status == ExecutionStatus.FAILED else {}),
|
||||
**({"endedTime": now} if status == ExecutionStatus.COMPLETED else {}),
|
||||
**({"executionData": json.dumps(execution_data)} if execution_data else {}),
|
||||
**({"stats": json.dumps(stats)} if stats else {}),
|
||||
}
|
||||
|
||||
res = await AgentNodeExecution.prisma().update(
|
||||
@@ -300,6 +304,26 @@ async def update_execution_status(
|
||||
return ExecutionResult.from_db(res)
|
||||
|
||||
|
||||
async def get_graph_execution(
|
||||
graph_exec_id: str, user_id: str
|
||||
) -> AgentGraphExecution | None:
|
||||
"""
|
||||
Retrieve a specific graph execution by its ID.
|
||||
|
||||
Args:
|
||||
graph_exec_id (str): The ID of the graph execution to retrieve.
|
||||
user_id (str): The ID of the user to whom the graph (execution) belongs.
|
||||
|
||||
Returns:
|
||||
AgentGraphExecution | None: The graph execution if found, None otherwise.
|
||||
"""
|
||||
execution = await AgentGraphExecution.prisma().find_first(
|
||||
where={"id": graph_exec_id, "userId": user_id},
|
||||
include=GRAPH_EXECUTION_INCLUDE,
|
||||
)
|
||||
return execution
|
||||
|
||||
|
||||
async def list_executions(graph_id: str, graph_version: int | None = None) -> list[str]:
|
||||
where: AgentGraphExecutionWhereInput = {"agentGraphId": graph_id}
|
||||
if graph_version is not None:
|
||||
|
||||
@@ -114,13 +114,15 @@ def SchemaField(
|
||||
exclude: bool = False,
|
||||
**kwargs,
|
||||
) -> T:
|
||||
json_extra: dict[str, Any] = {}
|
||||
if placeholder:
|
||||
json_extra["placeholder"] = placeholder
|
||||
if secret:
|
||||
json_extra["secret"] = True
|
||||
if advanced:
|
||||
json_extra["advanced"] = True
|
||||
json_extra = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"placeholder": placeholder,
|
||||
"secret": secret,
|
||||
"advanced": advanced,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
return Field(
|
||||
default,
|
||||
|
||||
@@ -41,7 +41,7 @@ class AsyncRedisEventQueue(AsyncEventQueue):
|
||||
def __init__(self):
|
||||
self.host = os.getenv("REDIS_HOST", "localhost")
|
||||
self.port = int(os.getenv("REDIS_PORT", "6379"))
|
||||
self.password = os.getenv("REDIS_PASSWORD", None)
|
||||
self.password = os.getenv("REDIS_PASSWORD", "password")
|
||||
self.queue_name = os.getenv("REDIS_QUEUE", "execution_events")
|
||||
self.connection = None
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import Future, ProcessPoolExecutor, TimeoutError
|
||||
import multiprocessing
|
||||
import threading
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from multiprocessing.pool import AsyncResult, Pool
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Generator, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -16,6 +19,7 @@ from autogpt_server.data.execution import (
|
||||
GraphExecution,
|
||||
NodeExecution,
|
||||
create_graph_execution,
|
||||
get_execution_results,
|
||||
get_incomplete_executions,
|
||||
get_latest_execution,
|
||||
merge_execution_input,
|
||||
@@ -30,12 +34,6 @@ from autogpt_server.data.graph import Graph, Link, Node, get_graph, get_node
|
||||
from autogpt_server.util import json
|
||||
from autogpt_server.util.decorator import error_logged, time_measured
|
||||
from autogpt_server.util.logging import configure_logging
|
||||
from autogpt_server.util.metrics import (
|
||||
metric_graph_count,
|
||||
metric_graph_timing,
|
||||
metric_node_payload,
|
||||
metric_node_timing,
|
||||
)
|
||||
from autogpt_server.util.service import AppService, expose, get_service_client
|
||||
from autogpt_server.util.settings import Config
|
||||
from autogpt_server.util.type import convert
|
||||
@@ -65,7 +63,10 @@ ExecutionStream = Generator[NodeExecution, None, None]
|
||||
|
||||
|
||||
def execute_node(
|
||||
loop: asyncio.AbstractEventLoop, api_client: "AgentServer", data: NodeExecution
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
api_client: "AgentServer",
|
||||
data: NodeExecution,
|
||||
execution_stats: dict[str, Any] | None = None,
|
||||
) -> ExecutionStream:
|
||||
"""
|
||||
Execute a node in the graph. This will trigger a block execution on a node,
|
||||
@@ -75,6 +76,7 @@ def execute_node(
|
||||
loop: The event loop to run the async functions.
|
||||
api_client: The client to send execution updates to the server.
|
||||
data: The execution data for executing the current node.
|
||||
execution_stats: The execution statistics to be updated.
|
||||
|
||||
Returns:
|
||||
The subsequent node to be enqueued, or None if there is no subsequent node.
|
||||
@@ -120,17 +122,18 @@ def execute_node(
|
||||
|
||||
# Execute the node
|
||||
input_data_str = json.dumps(input_data)
|
||||
metric_node_payload("input_size", len(input_data_str), tags=log_metadata)
|
||||
input_size = len(input_data_str)
|
||||
logger.info(
|
||||
"Executed node with input",
|
||||
extra={"json_fields": {**log_metadata, "input": input_data_str}},
|
||||
)
|
||||
update_execution(ExecutionStatus.RUNNING)
|
||||
|
||||
output_size = 0
|
||||
try:
|
||||
for output_name, output_data in node_block.execute(input_data):
|
||||
output_data_str = json.dumps(output_data)
|
||||
metric_node_payload("output_size", len(output_data_str), tags=log_metadata)
|
||||
output_size += len(output_data_str)
|
||||
logger.info(
|
||||
"Node produced output",
|
||||
extra={"json_fields": {**log_metadata, output_name: output_data_str}},
|
||||
@@ -161,6 +164,11 @@ def execute_node(
|
||||
|
||||
raise e
|
||||
|
||||
finally:
|
||||
if execution_stats is not None:
|
||||
execution_stats["input_size"] = input_size
|
||||
execution_stats["output_size"] = output_size
|
||||
|
||||
|
||||
@contextmanager
|
||||
def synchronized(api_client: "AgentServer", key: Any):
|
||||
@@ -409,23 +417,24 @@ class Executor:
|
||||
node_id=data.node_id,
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, _ = cls._on_node_execution(q, data, log_metadata)
|
||||
metric_node_timing("walltime", timing_info.wall_time, tags=log_metadata)
|
||||
metric_node_timing("cputime", timing_info.cpu_time, tags=log_metadata)
|
||||
|
||||
execution_stats = {}
|
||||
timing_info, _ = cls._on_node_execution(q, data, log_metadata, execution_stats)
|
||||
execution_stats["walltime"] = timing_info.wall_time
|
||||
execution_stats["cputime"] = timing_info.cpu_time
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_node_execution_stats(
|
||||
data.node_exec_id,
|
||||
{
|
||||
"walltime": timing_info.wall_time,
|
||||
"cputime": timing_info.cpu_time,
|
||||
},
|
||||
)
|
||||
update_node_execution_stats(data.node_exec_id, execution_stats)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_node_execution(
|
||||
cls, q: ExecutionQueue[NodeExecution], data: NodeExecution, log_metadata: dict
|
||||
cls,
|
||||
q: ExecutionQueue[NodeExecution],
|
||||
d: NodeExecution,
|
||||
log_metadata: dict,
|
||||
stats: dict[str, Any] | None = None,
|
||||
):
|
||||
try:
|
||||
cls.logger.info(
|
||||
@@ -436,7 +445,7 @@ class Executor:
|
||||
}
|
||||
},
|
||||
)
|
||||
for execution in execute_node(cls.loop, cls.agent_server_client, data):
|
||||
for execution in execute_node(cls.loop, cls.agent_server_client, d, stats):
|
||||
q.add(execution)
|
||||
cls.logger.info(
|
||||
"Finished node execution",
|
||||
@@ -461,15 +470,19 @@ class Executor:
|
||||
cls.loop = asyncio.new_event_loop()
|
||||
cls.loop.run_until_complete(db.connect())
|
||||
cls.pool_size = Config().num_node_workers
|
||||
cls.executor = ProcessPoolExecutor(
|
||||
max_workers=cls.pool_size,
|
||||
cls._init_node_executor_pool()
|
||||
logger.info(f"Graph executor started with max-{cls.pool_size} node workers.")
|
||||
|
||||
@classmethod
|
||||
def _init_node_executor_pool(cls):
|
||||
cls.executor = Pool(
|
||||
processes=cls.pool_size,
|
||||
initializer=cls.on_node_executor_start,
|
||||
)
|
||||
cls.logger.info(f"Graph executor started with max-{cls.pool_size} node workers")
|
||||
|
||||
@classmethod
|
||||
@error_logged
|
||||
def on_graph_execution(cls, data: GraphExecution):
|
||||
def on_graph_execution(cls, data: GraphExecution, cancel: threading.Event):
|
||||
log_metadata = get_log_metadata(
|
||||
graph_eid=data.graph_exec_id,
|
||||
graph_id=data.graph_id,
|
||||
@@ -477,10 +490,7 @@ class Executor:
|
||||
node_eid="*",
|
||||
block_name="-",
|
||||
)
|
||||
timing_info, node_count = cls._on_graph_execution(data, log_metadata)
|
||||
metric_graph_timing("walltime", timing_info.wall_time, tags=log_metadata)
|
||||
metric_graph_timing("cputime", timing_info.cpu_time, tags=log_metadata)
|
||||
metric_graph_count("nodecount", node_count, tags=log_metadata)
|
||||
timing_info, node_count = cls._on_graph_execution(data, cancel, log_metadata)
|
||||
|
||||
cls.loop.run_until_complete(
|
||||
update_graph_execution_stats(
|
||||
@@ -495,7 +505,9 @@ class Executor:
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(cls, graph_data: GraphExecution, log_metadata: dict) -> int:
|
||||
def _on_graph_execution(
|
||||
cls, graph_data: GraphExecution, cancel: threading.Event, log_metadata: dict
|
||||
) -> int:
|
||||
cls.logger.info(
|
||||
"Start graph execution",
|
||||
extra={
|
||||
@@ -504,38 +516,85 @@ class Executor:
|
||||
}
|
||||
},
|
||||
)
|
||||
node_executed = 0
|
||||
n_node_executions = 0
|
||||
finished = False
|
||||
|
||||
def cancel_handler():
|
||||
while not cancel.is_set():
|
||||
cancel.wait(1)
|
||||
if finished:
|
||||
return
|
||||
cls.executor.terminate()
|
||||
logger.info(
|
||||
f"Terminated graph execution {graph_data.graph_exec_id}",
|
||||
extra={"json_fields": {**log_metadata}},
|
||||
)
|
||||
cls._init_node_executor_pool()
|
||||
|
||||
cancel_thread = threading.Thread(target=cancel_handler)
|
||||
cancel_thread.start()
|
||||
|
||||
try:
|
||||
queue = ExecutionQueue[NodeExecution]()
|
||||
for node_exec in graph_data.start_node_execs:
|
||||
queue.add(node_exec)
|
||||
|
||||
futures: dict[str, Future] = {}
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecution):
|
||||
node_id = exec_data.node_id
|
||||
|
||||
def callback(_):
|
||||
running_executions.pop(node_id)
|
||||
nonlocal n_node_executions
|
||||
n_node_executions += 1
|
||||
|
||||
return callback
|
||||
|
||||
while not queue.empty():
|
||||
execution = queue.get()
|
||||
if cancel.is_set():
|
||||
return n_node_executions
|
||||
|
||||
exec_data = queue.get()
|
||||
|
||||
# Avoid parallel execution of the same node.
|
||||
fut = futures.get(execution.node_id)
|
||||
if fut and not fut.done():
|
||||
execution = running_executions.get(exec_data.node_id)
|
||||
if execution and not execution.ready():
|
||||
# TODO (performance improvement):
|
||||
# Wait for the completion of the same node execution is blocking.
|
||||
# To improve this we need a separate queue for each node.
|
||||
# Re-enqueueing the data back to the queue will disrupt the order.
|
||||
cls.wait_future(fut, timeout=None)
|
||||
execution.wait()
|
||||
|
||||
futures[execution.node_id] = cls.executor.submit(
|
||||
cls.on_node_execution, queue, execution
|
||||
logger.debug(f"Dispatching execution of node {exec_data.node_id}")
|
||||
running_executions[exec_data.node_id] = cls.executor.apply_async(
|
||||
cls.on_node_execution,
|
||||
(queue, exec_data),
|
||||
callback=make_exec_callback(exec_data),
|
||||
)
|
||||
|
||||
# Avoid terminating graph execution when some nodes are still running.
|
||||
while queue.empty() and futures:
|
||||
for node_id, future in list(futures.items()):
|
||||
if future.done():
|
||||
node_executed += 1
|
||||
del futures[node_id]
|
||||
elif queue.empty():
|
||||
cls.wait_future(future)
|
||||
while queue.empty() and running_executions:
|
||||
logger.debug(
|
||||
"Queue empty; running nodes: "
|
||||
f"{list(running_executions.keys())}"
|
||||
)
|
||||
for node_id, execution in list(running_executions.items()):
|
||||
if cancel.is_set():
|
||||
return n_node_executions
|
||||
|
||||
if not queue.empty():
|
||||
logger.debug(
|
||||
"Queue no longer empty! Returning to dispatching loop."
|
||||
)
|
||||
break # yield to parent loop to execute new queue items
|
||||
|
||||
logger.debug(f"Waiting on execution of node {node_id}")
|
||||
execution.wait(3)
|
||||
logger.debug(
|
||||
f"State of execution of node {node_id} after waiting: "
|
||||
f"{'DONE' if execution.ready() else 'RUNNING'}"
|
||||
)
|
||||
|
||||
cls.logger.info(
|
||||
"Finished graph execution",
|
||||
@@ -546,7 +605,7 @@ class Executor:
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
cls.logger.exception(
|
||||
logger.exception(
|
||||
f"Failed graph execution: {e}",
|
||||
extra={
|
||||
"json_fields": {
|
||||
@@ -554,25 +613,21 @@ class Executor:
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
return node_executed
|
||||
|
||||
@classmethod
|
||||
def wait_future(cls, future: Future, timeout: int | None = 3):
|
||||
try:
|
||||
if not future.done():
|
||||
future.result(timeout=timeout)
|
||||
except TimeoutError:
|
||||
# Avoid being blocked by long-running node, by not waiting its completion.
|
||||
pass
|
||||
finally:
|
||||
if not cancel.is_set():
|
||||
finished = True
|
||||
cancel.set()
|
||||
cancel_thread.join()
|
||||
return n_node_executions
|
||||
|
||||
|
||||
class ExecutionManager(AppService):
|
||||
def __init__(self):
|
||||
super().__init__(port=8002)
|
||||
self.use_db = True
|
||||
self.pool_size = Config().num_graph_workers
|
||||
self.queue = ExecutionQueue[GraphExecution]()
|
||||
self.use_redis = False
|
||||
self.active_graph_runs: dict[str, tuple[Future, threading.Event]] = {}
|
||||
|
||||
# def __del__(self):
|
||||
# self.sync_manager.shutdown()
|
||||
@@ -582,11 +637,21 @@ class ExecutionManager(AppService):
|
||||
max_workers=self.pool_size,
|
||||
initializer=Executor.on_graph_executor_start,
|
||||
) as executor:
|
||||
sync_manager = multiprocessing.Manager()
|
||||
logger.info(
|
||||
f"Execution manager started with max-{self.pool_size} graph workers."
|
||||
)
|
||||
while True:
|
||||
executor.submit(Executor.on_graph_execution, self.queue.get())
|
||||
graph_exec_data = self.queue.get()
|
||||
graph_exec_id = graph_exec_data.graph_exec_id
|
||||
cancel_event = sync_manager.Event()
|
||||
future = executor.submit(
|
||||
Executor.on_graph_execution, graph_exec_data, cancel_event
|
||||
)
|
||||
self.active_graph_runs[graph_exec_id] = (future, cancel_event)
|
||||
future.add_done_callback(
|
||||
lambda _: self.active_graph_runs.pop(graph_exec_id)
|
||||
)
|
||||
|
||||
@property
|
||||
def agent_server_client(self) -> "AgentServer":
|
||||
@@ -595,7 +660,7 @@ class ExecutionManager(AppService):
|
||||
@expose
|
||||
def add_execution(
|
||||
self, graph_id: str, data: BlockInput, user_id: str
|
||||
) -> dict[Any, Any]:
|
||||
) -> dict[str, Any]:
|
||||
graph: Graph | None = self.run_and_wait(get_graph(graph_id, user_id=user_id))
|
||||
if not graph:
|
||||
raise Exception(f"Graph #{graph_id} not found.")
|
||||
@@ -648,4 +713,45 @@ class ExecutionManager(AppService):
|
||||
)
|
||||
self.queue.add(graph_exec)
|
||||
|
||||
return {"id": graph_exec_id}
|
||||
return graph_exec.model_dump()
|
||||
|
||||
@expose
|
||||
def cancel_execution(self, graph_exec_id: str) -> None:
|
||||
"""
|
||||
Mechanism:
|
||||
1. Set the cancel event
|
||||
2. Graph executor's cancel handler thread detects the event, terminates workers,
|
||||
reinitializes worker pool, and returns.
|
||||
3. Update execution statuses in DB and set `error` outputs to `"TERMINATED"`.
|
||||
"""
|
||||
if graph_exec_id not in self.active_graph_runs:
|
||||
raise Exception(
|
||||
f"Graph execution #{graph_exec_id} not active/running: "
|
||||
"possibly already completed/cancelled."
|
||||
)
|
||||
|
||||
future, cancel_event = self.active_graph_runs[graph_exec_id]
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
|
||||
cancel_event.set()
|
||||
future.result()
|
||||
|
||||
# Update the status of the unfinished node executions
|
||||
node_execs = self.run_and_wait(get_execution_results(graph_exec_id))
|
||||
for node_exec in node_execs:
|
||||
if node_exec.status not in (
|
||||
ExecutionStatus.COMPLETED,
|
||||
ExecutionStatus.FAILED,
|
||||
):
|
||||
self.run_and_wait(
|
||||
upsert_execution_output(
|
||||
node_exec.node_exec_id, "error", "TERMINATED"
|
||||
)
|
||||
)
|
||||
exec_update = self.run_and_wait(
|
||||
update_execution_status(
|
||||
node_exec.node_exec_id, ExecutionStatus.FAILED
|
||||
)
|
||||
)
|
||||
self.agent_server_client.send_execution_update(exec_update.model_dump())
|
||||
|
||||
@@ -20,6 +20,7 @@ def log(msg, **kwargs):
|
||||
class ExecutionScheduler(AppService):
|
||||
def __init__(self, refresh_interval=10):
|
||||
super().__init__(port=8003)
|
||||
self.use_db = True
|
||||
self.last_check = datetime.min
|
||||
self.refresh_interval = refresh_interval
|
||||
self.use_redis = False
|
||||
|
||||
20
rnd/autogpt_server/autogpt_server/rest.py
Normal file
20
rnd/autogpt_server/autogpt_server/rest.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from autogpt_server.app import run_processes
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
from autogpt_server.server import AgentServer
|
||||
from autogpt_server.util.service import PyroNameServer
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run all the processes required for the AutoGPT-server REST API.
|
||||
"""
|
||||
run_processes(
|
||||
PyroNameServer(),
|
||||
ExecutionManager(),
|
||||
ExecutionScheduler(),
|
||||
AgentServer(),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,3 +1,4 @@
|
||||
from .rest_api import AgentServer
|
||||
from .ws_api import WebsocketServer
|
||||
|
||||
__all__ = ["AgentServer"]
|
||||
__all__ = ["AgentServer", "WebsocketServer"]
|
||||
|
||||
@@ -11,14 +11,10 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from autogpt_server.data import block, db
|
||||
from autogpt_server.data import execution as execution_db
|
||||
from autogpt_server.data import graph as graph_db
|
||||
from autogpt_server.data import user as user_db
|
||||
from autogpt_server.data.block import BlockInput, CompletedBlockOutput
|
||||
from autogpt_server.data.execution import (
|
||||
ExecutionResult,
|
||||
get_execution_results,
|
||||
list_executions,
|
||||
)
|
||||
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from autogpt_server.data.user import get_or_create_user
|
||||
from autogpt_server.executor import ExecutionManager, ExecutionScheduler
|
||||
@@ -33,7 +29,6 @@ settings = Settings()
|
||||
|
||||
class AgentServer(AppService):
|
||||
mutex = KeyedMutex()
|
||||
use_db = False
|
||||
use_redis = True
|
||||
_test_dependency_overrides = {}
|
||||
|
||||
@@ -171,10 +166,15 @@ class AgentServer(AppService):
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{run_id}",
|
||||
endpoint=self.get_run_execution_results,
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}",
|
||||
endpoint=self.get_graph_run_node_execution_results,
|
||||
methods=["GET"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/graphs/{graph_id}/executions/{graph_exec_id}/stop",
|
||||
endpoint=self.stop_graph_run,
|
||||
methods=["POST"],
|
||||
)
|
||||
router.add_api_route(
|
||||
path="/graphs/{graph_id}/schedules",
|
||||
endpoint=self.create_schedule,
|
||||
@@ -424,15 +424,29 @@ class AgentServer(AppService):
|
||||
graph_id: str,
|
||||
node_input: dict[Any, Any],
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> dict[Any, Any]:
|
||||
) -> dict[str, Any]: # FIXME: add proper return type
|
||||
try:
|
||||
return self.execution_manager_client.add_execution(
|
||||
graph_exec = self.execution_manager_client.add_execution(
|
||||
graph_id, node_input, user_id=user_id
|
||||
)
|
||||
return {"id": graph_exec["graph_exec_id"]}
|
||||
except Exception as e:
|
||||
msg = e.__str__().encode().decode("unicode_escape")
|
||||
raise HTTPException(status_code=400, detail=msg)
|
||||
|
||||
async def stop_graph_run(
|
||||
self, graph_exec_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
if not await execution_db.get_graph_execution(graph_exec_id, user_id):
|
||||
raise HTTPException(
|
||||
404, detail=f"Agent execution #{graph_exec_id} not found"
|
||||
)
|
||||
|
||||
self.execution_manager_client.cancel_execution(graph_exec_id)
|
||||
|
||||
# Retrieve & return canceled graph execution in its final state
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
@classmethod
|
||||
async def get_graph_input_schema(
|
||||
cls,
|
||||
@@ -459,17 +473,20 @@ class AgentServer(AppService):
|
||||
status_code=404, detail=f"Agent #{graph_id}{rev} not found."
|
||||
)
|
||||
|
||||
return await list_executions(graph_id, graph_version)
|
||||
return await execution_db.list_executions(graph_id, graph_version)
|
||||
|
||||
@classmethod
|
||||
async def get_run_execution_results(
|
||||
cls, graph_id: str, run_id: str, user_id: Annotated[str, Depends(get_user_id)]
|
||||
) -> list[ExecutionResult]:
|
||||
async def get_graph_run_node_execution_results(
|
||||
cls,
|
||||
graph_id: str,
|
||||
graph_exec_id: str,
|
||||
user_id: Annotated[str, Depends(get_user_id)],
|
||||
) -> list[execution_db.ExecutionResult]:
|
||||
graph = await graph_db.get_graph(graph_id, user_id=user_id)
|
||||
if not graph:
|
||||
raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
|
||||
|
||||
return await get_execution_results(run_id)
|
||||
return await execution_db.get_execution_results(graph_exec_id)
|
||||
|
||||
async def create_schedule(
|
||||
self,
|
||||
@@ -507,7 +524,7 @@ class AgentServer(AppService):
|
||||
|
||||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
execution_result = ExecutionResult(**execution_result_dict)
|
||||
execution_result = execution_db.ExecutionResult(**execution_result_dict)
|
||||
self.run_and_wait(self.event_queue.put(execution_result))
|
||||
|
||||
@expose
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -9,6 +10,7 @@ from autogpt_server.data.queue import AsyncRedisEventQueue
|
||||
from autogpt_server.data.user import DEFAULT_USER_ID
|
||||
from autogpt_server.server.conn_manager import ConnectionManager
|
||||
from autogpt_server.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
from autogpt_server.util.service import AppProcess
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
@@ -166,3 +168,8 @@ async def websocket_router(
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket)
|
||||
logging.info("Client Disconnected")
|
||||
|
||||
|
||||
class WebsocketServer(AppProcess):
|
||||
def run(self):
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
||||
@@ -48,13 +48,13 @@ def create_test_graph() -> graph.Graph:
|
||||
graph.Link(
|
||||
source_id=nodes[0].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="value",
|
||||
source_name="result",
|
||||
sink_name="values_#_a",
|
||||
),
|
||||
graph.Link(
|
||||
source_id=nodes[1].id,
|
||||
sink_id=nodes[2].id,
|
||||
source_name="value",
|
||||
source_name="result",
|
||||
sink_name="values_#_b",
|
||||
),
|
||||
graph.Link(
|
||||
|
||||
@@ -12,21 +12,20 @@ class KeyedMutex:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.locks: dict[Any, Lock] = ExpiringDict(max_len=6000, max_age_seconds=60)
|
||||
self.locks: dict[Any, tuple[Lock, int]] = ExpiringDict(
|
||||
max_len=6000, max_age_seconds=60
|
||||
)
|
||||
self.locks_lock = Lock()
|
||||
|
||||
def lock(self, key: Any):
|
||||
with self.locks_lock:
|
||||
if key not in self.locks:
|
||||
self.locks[key] = (lock := Lock())
|
||||
else:
|
||||
lock = self.locks[key]
|
||||
lock, request_count = self.locks.get(key, (Lock(), 0))
|
||||
self.locks[key] = (lock, request_count + 1)
|
||||
lock.acquire()
|
||||
|
||||
def unlock(self, key: Any):
|
||||
with self.locks_lock:
|
||||
if key in self.locks:
|
||||
lock = self.locks.pop(key)
|
||||
else:
|
||||
return
|
||||
lock, request_count = self.locks.pop(key)
|
||||
if request_count > 1:
|
||||
self.locks[key] = (lock, request_count - 1)
|
||||
lock.release()
|
||||
|
||||
@@ -1,38 +1,8 @@
|
||||
import sentry_sdk
|
||||
from sentry_sdk import metrics
|
||||
|
||||
from autogpt_server.util.settings import Settings
|
||||
|
||||
sentry_dsn = Settings().secrets.sentry_dsn
|
||||
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)
|
||||
|
||||
|
||||
def emit_distribution(
|
||||
name: str,
|
||||
key: str,
|
||||
value: float,
|
||||
unit: str = "none",
|
||||
tags: dict[str, str] | None = None,
|
||||
):
|
||||
metrics.distribution(
|
||||
key=f"{name}__{key}",
|
||||
value=value,
|
||||
unit=unit,
|
||||
tags=tags or {},
|
||||
)
|
||||
|
||||
|
||||
def metric_node_payload(key: str, value: float, tags: dict[str, str]):
|
||||
emit_distribution("NODE_EXECUTION", key, value, unit="byte", tags=tags)
|
||||
|
||||
|
||||
def metric_node_timing(key: str, value: float, tags: dict[str, str]):
|
||||
emit_distribution("NODE_EXECUTION", key, value, unit="second", tags=tags)
|
||||
|
||||
|
||||
def metric_graph_count(key: str, value: int, tags: dict[str, str]):
|
||||
emit_distribution("GRAPH_EXECUTION", key, value, tags=tags)
|
||||
|
||||
|
||||
def metric_graph_timing(key: str, value: float, tags: dict[str, str]):
|
||||
emit_distribution("GRAPH_EXECUTION", key, value, unit="second", tags=tags)
|
||||
def sentry_init():
|
||||
sentry_dsn = Settings().secrets.sentry_dsn
|
||||
sentry_sdk.init(dsn=sentry_dsn, traces_sample_rate=1.0, profiles_sample_rate=1.0)
|
||||
|
||||
@@ -4,6 +4,9 @@ from abc import ABC, abstractmethod
|
||||
from multiprocessing import Process, set_start_method
|
||||
from typing import Optional
|
||||
|
||||
from autogpt_server.util.logging import configure_logging
|
||||
from autogpt_server.util.metrics import sentry_init
|
||||
|
||||
|
||||
class AppProcess(ABC):
|
||||
"""
|
||||
@@ -11,7 +14,10 @@ class AppProcess(ABC):
|
||||
"""
|
||||
|
||||
process: Optional[Process] = None
|
||||
|
||||
set_start_method("spawn", force=True)
|
||||
configure_logging()
|
||||
sentry_init()
|
||||
|
||||
@abstractmethod
|
||||
def run(self):
|
||||
@@ -20,15 +26,17 @@ class AppProcess(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def health_check(self):
|
||||
"""
|
||||
A method to check the health of the process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def execute_run_command(self, silent):
|
||||
try:
|
||||
if silent:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
sys.stderr = open(os.devnull, "w")
|
||||
else:
|
||||
from .logging import configure_logging
|
||||
|
||||
configure_logging()
|
||||
self.run()
|
||||
except KeyboardInterrupt or SystemExit as e:
|
||||
print(f"Process terminated: {e}")
|
||||
@@ -61,6 +69,7 @@ class AppProcess(ABC):
|
||||
**proc_args,
|
||||
)
|
||||
self.process.start()
|
||||
self.health_check()
|
||||
return self.process.pid or 0
|
||||
|
||||
def stop(self):
|
||||
|
||||
7
rnd/autogpt_server/autogpt_server/util/retry.py
Normal file
7
rnd/autogpt_server/autogpt_server/util/retry.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
conn_retry = retry(
|
||||
stop=stop_after_attempt(30),
|
||||
wait=wait_exponential(multiplier=1, min=1, max=30),
|
||||
reraise=True,
|
||||
)
|
||||
@@ -9,17 +9,14 @@ from typing import Any, Callable, Coroutine, Type, TypeVar, cast
|
||||
import Pyro5.api
|
||||
from Pyro5 import api as pyro
|
||||
from Pyro5 import nameserver
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from autogpt_server.data import db
|
||||
from autogpt_server.data.queue import AsyncEventQueue, AsyncRedisEventQueue
|
||||
from autogpt_server.util.process import AppProcess
|
||||
from autogpt_server.util.retry import conn_retry
|
||||
from autogpt_server.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
conn_retry = retry(
|
||||
stop=stop_after_attempt(30), wait=wait_exponential(multiplier=1, min=1, max=30)
|
||||
)
|
||||
T = TypeVar("T")
|
||||
C = TypeVar("C", bound=Callable)
|
||||
|
||||
@@ -55,11 +52,19 @@ class PyroNameServer(AppProcess):
|
||||
except KeyboardInterrupt:
|
||||
print("Shutting down NameServer")
|
||||
|
||||
@conn_retry
|
||||
def _wait_for_ns(self):
|
||||
pyro.locate_ns(host="localhost", port=9090)
|
||||
print("NameServer is ready")
|
||||
|
||||
def health_check(self):
|
||||
self._wait_for_ns()
|
||||
|
||||
|
||||
class AppService(AppProcess):
|
||||
shared_event_loop: asyncio.AbstractEventLoop
|
||||
event_queue: AsyncEventQueue = AsyncRedisEventQueue()
|
||||
use_db: bool = True
|
||||
use_db: bool = False
|
||||
use_redis: bool = False
|
||||
|
||||
def __init__(self, port):
|
||||
|
||||
@@ -59,7 +59,6 @@ class SpinTestServer:
|
||||
return "3e53486c-cf57-477e-ba2a-cb02dc828e1a"
|
||||
|
||||
async def __aenter__(self):
|
||||
|
||||
self.name_server.__enter__()
|
||||
self.setup_dependency_overrides()
|
||||
self.agent_server.__enter__()
|
||||
@@ -95,7 +94,7 @@ async def wait_execution(
|
||||
timeout: int = 20,
|
||||
) -> list:
|
||||
async def is_execution_completed():
|
||||
execs = await AgentServer().get_run_execution_results(
|
||||
execs = await AgentServer().get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
return (
|
||||
@@ -110,7 +109,7 @@ async def wait_execution(
|
||||
# Wait for the executions to complete
|
||||
for i in range(timeout):
|
||||
if await is_execution_completed():
|
||||
return await AgentServer().get_run_execution_results(
|
||||
return await AgentServer().get_graph_run_node_execution_results(
|
||||
graph_id, graph_exec_id, user_id
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
13
rnd/autogpt_server/autogpt_server/ws.py
Normal file
13
rnd/autogpt_server/autogpt_server/ws.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from autogpt_server.app import run_processes
|
||||
from autogpt_server.server.ws_api import WebsocketServer
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Run all the processes required for the AutoGPT-server WebSocket API.
|
||||
"""
|
||||
run_processes(WebsocketServer())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,11 +0,0 @@
|
||||
import uvicorn
|
||||
|
||||
from autogpt_server.server.ws_api import app
|
||||
|
||||
|
||||
def main():
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,4 @@
|
||||
{
|
||||
"num_graph_workers": 1,
|
||||
"num_node_workers": 1
|
||||
"num_graph_workers": 10,
|
||||
"num_node_workers": 5
|
||||
}
|
||||
|
||||
4
rnd/autogpt_server/poetry.lock
generated
4
rnd/autogpt_server/poetry.lock
generated
@@ -25,7 +25,7 @@ requests = "*"
|
||||
sentry-sdk = "^1.40.4"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark"]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
@@ -386,7 +386,7 @@ watchdog = "4.0.0"
|
||||
webdriver-manager = "^4.0.1"
|
||||
|
||||
[package.extras]
|
||||
benchmark = ["agbenchmark"]
|
||||
benchmark = ["agbenchmark @ file:///Users/majdyz/Code/AutoGPT/benchmark"]
|
||||
|
||||
[package.source]
|
||||
type = "directory"
|
||||
|
||||
@@ -64,7 +64,8 @@ build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
app = "autogpt_server.app:main"
|
||||
ws = "autogpt_server.ws_app:main"
|
||||
rest = "autogpt_server.rest:main"
|
||||
ws = "autogpt_server.ws:main"
|
||||
manager = "autogpt_server.app:execution_manager"
|
||||
cli = "autogpt_server.cli:main"
|
||||
format = "linter:format"
|
||||
|
||||
@@ -35,11 +35,11 @@ async def assert_sample_graph_executions(
|
||||
test_user: User,
|
||||
graph_exec_id: str,
|
||||
):
|
||||
executions = await agent_server.get_run_execution_results(
|
||||
executions = await agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
|
||||
output_list = [{"value": ["Hello"]}, {"value": ["World"]}]
|
||||
output_list = [{"result": ["Hello"]}, {"result": ["World"]}]
|
||||
input_list = [
|
||||
{"value": "Hello", "name": "input_1"},
|
||||
{"value": "World", "name": "input_2"},
|
||||
@@ -156,7 +156,7 @@ async def test_input_pin_always_waited(server: SpinTestServer):
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 3
|
||||
)
|
||||
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 3
|
||||
@@ -236,7 +236,7 @@ async def test_static_input_link_on_graph(server: SpinTestServer):
|
||||
graph_exec_id = await execute_graph(
|
||||
server.agent_server, server.exec_manager, test_graph, test_user, {}, 8
|
||||
)
|
||||
executions = await server.agent_server.get_run_execution_results(
|
||||
executions = await server.agent_server.get_graph_run_node_execution_results(
|
||||
test_graph.id, graph_exec_id, test_user.id
|
||||
)
|
||||
assert len(executions) == 8
|
||||
|
||||
@@ -82,6 +82,6 @@ env:
|
||||
APP_ENV: "dev"
|
||||
PYRO_HOST: "0.0.0.0"
|
||||
NUM_GRAPH_WORKERS: 100
|
||||
NUM_NODE_WORKERS: 100
|
||||
NUM_NODE_WORKERS: 5
|
||||
REDIS_HOST: "redis-dev-master.redis-dev.svc.cluster.local"
|
||||
REDIS_PORT: "6379"
|
||||
REDIS_PORT: "6379"
|
||||
|
||||
@@ -65,7 +65,7 @@ app.add_middleware(
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
"https://dev-builder.agpt.co"
|
||||
"https://dev-builder.agpt.co",
|
||||
],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
|
||||
@@ -403,7 +403,7 @@ async def get_top_agents_by_downloads(
|
||||
|
||||
async def set_agent_featured(
|
||||
agent_id: str, is_active: bool = True, featured_categories: list[str] = ["featured"]
|
||||
):
|
||||
) -> prisma.models.FeaturedAgent:
|
||||
"""Set an agent as featured in the database.
|
||||
|
||||
Args:
|
||||
@@ -418,7 +418,7 @@ async def set_agent_featured(
|
||||
if not agent:
|
||||
raise AgentQueryError(f"Agent with ID {agent_id} not found.")
|
||||
|
||||
await prisma.models.FeaturedAgent.prisma().upsert(
|
||||
featured = await prisma.models.FeaturedAgent.prisma().upsert(
|
||||
where={"agentId": agent_id},
|
||||
data={
|
||||
"update": {
|
||||
@@ -432,6 +432,7 @@ async def set_agent_featured(
|
||||
},
|
||||
},
|
||||
)
|
||||
return featured
|
||||
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
@@ -553,3 +554,90 @@ async def add_featured_category(
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
async def get_agent_featured(agent_id: str) -> prisma.models.FeaturedAgent | None:
|
||||
"""Retrieve an agent's featured categories from the database.
|
||||
|
||||
Args:
|
||||
agent_id (str): The ID of the agent.
|
||||
|
||||
Returns:
|
||||
FeaturedAgentResponse: The list of featured agents.
|
||||
"""
|
||||
try:
|
||||
featured_agent = await prisma.models.FeaturedAgent.prisma().find_unique(
|
||||
where={"agentId": agent_id},
|
||||
)
|
||||
return featured_agent
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
async def get_not_featured_agents(
|
||||
page: int = 1, page_size: int = 10
|
||||
) -> typing.List[prisma.models.Agents]:
|
||||
"""
|
||||
Retrieve a list of not featured agents from the database.
|
||||
"""
|
||||
try:
|
||||
agents = await prisma.client.get_client().query_raw(
|
||||
query=f"""
|
||||
SELECT
|
||||
"Agents".id,
|
||||
"Agents"."createdAt",
|
||||
"Agents"."updatedAt",
|
||||
"Agents".version,
|
||||
"Agents".name,
|
||||
LEFT("Agents".description, 500) AS description,
|
||||
"Agents".author,
|
||||
"Agents".keywords,
|
||||
"Agents".categories,
|
||||
"Agents".graph,
|
||||
"Agents"."submissionStatus",
|
||||
"Agents"."submissionDate",
|
||||
"Agents".search::text AS search
|
||||
FROM "Agents"
|
||||
LEFT JOIN "FeaturedAgent" ON "Agents"."id" = "FeaturedAgent"."agentId"
|
||||
WHERE ("FeaturedAgent"."agentId" IS NULL OR "FeaturedAgent"."featuredCategories" = '{{}}')
|
||||
AND "Agents"."submissionStatus" = 'APPROVED'
|
||||
ORDER BY "Agents"."createdAt" DESC
|
||||
LIMIT {page_size} OFFSET {page_size * (page - 1)}
|
||||
""",
|
||||
model=prisma.models.Agents,
|
||||
)
|
||||
return agents
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
|
||||
|
||||
async def get_all_categories() -> market.model.CategoriesResponse:
|
||||
"""
|
||||
Retrieve all unique categories from the database.
|
||||
|
||||
Returns:
|
||||
CategoriesResponse: A list of unique categories.
|
||||
"""
|
||||
try:
|
||||
categories = await prisma.client.get_client().query_first(
|
||||
query="""
|
||||
SELECT ARRAY_AGG(DISTINCT category ORDER BY category) AS unique_categories
|
||||
FROM (
|
||||
SELECT UNNEST(categories) AS category
|
||||
FROM "Agents"
|
||||
) subquery;
|
||||
""",
|
||||
model=market.model.CategoriesResponse,
|
||||
)
|
||||
if not categories:
|
||||
raise AgentQueryError("No categories found")
|
||||
|
||||
return categories
|
||||
except prisma.errors.PrismaError as e:
|
||||
raise AgentQueryError(f"Database query failed: {str(e)}")
|
||||
except Exception as e:
|
||||
raise AgentQueryError(f"Unexpected error occurred: {str(e)}")
|
||||
|
||||
@@ -95,3 +95,26 @@ class AgentDetailResponse(pydantic.BaseModel):
|
||||
createdAt: datetime.datetime
|
||||
updatedAt: datetime.datetime
|
||||
graph: dict[str, typing.Any]
|
||||
|
||||
|
||||
class FeaturedAgentResponse(pydantic.BaseModel):
|
||||
"""
|
||||
Represents the response data for an agent detail.
|
||||
"""
|
||||
|
||||
agentId: str
|
||||
featuredCategories: list[str]
|
||||
createdAt: datetime.datetime
|
||||
updatedAt: datetime.datetime
|
||||
isActive: bool
|
||||
|
||||
|
||||
class CategoriesResponse(pydantic.BaseModel):
|
||||
"""
|
||||
Represents the response data for a list of categories.
|
||||
|
||||
Attributes:
|
||||
unique_categories (list[str]): The list of unique categories.
|
||||
"""
|
||||
|
||||
unique_categories: list[str]
|
||||
|
||||
@@ -46,19 +46,44 @@ async def create_agent_entry(
|
||||
@router.post("/agent/featured/{agent_id}")
|
||||
async def set_agent_featured(
|
||||
agent_id: str,
|
||||
category: list[str] = ["featured"],
|
||||
categories: list[str] = fastapi.Query(
|
||||
default=["featured"],
|
||||
description="The categories to set the agent as featured in",
|
||||
),
|
||||
user: autogpt_libs.auth.User = fastapi.Depends(
|
||||
autogpt_libs.auth.requires_admin_user
|
||||
),
|
||||
):
|
||||
) -> market.model.FeaturedAgentResponse:
|
||||
"""
|
||||
A basic endpoint to set an agent as featured in the database.
|
||||
"""
|
||||
try:
|
||||
await market.db.set_agent_featured(
|
||||
agent_id, is_active=True, featured_categories=category
|
||||
agent = await market.db.set_agent_featured(
|
||||
agent_id, is_active=True, featured_categories=categories
|
||||
)
|
||||
return fastapi.responses.Response(status_code=200)
|
||||
return market.model.FeaturedAgentResponse(**agent.model_dump())
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/agent/featured/{agent_id}")
|
||||
async def get_agent_featured(
|
||||
agent_id: str,
|
||||
user: autogpt_libs.auth.User = fastapi.Depends(
|
||||
autogpt_libs.auth.requires_admin_user
|
||||
),
|
||||
) -> market.model.FeaturedAgentResponse | None:
|
||||
"""
|
||||
A basic endpoint to get an agent as featured in the database.
|
||||
"""
|
||||
try:
|
||||
agent = await market.db.get_agent_featured(agent_id)
|
||||
if agent:
|
||||
return market.model.FeaturedAgentResponse(**agent.model_dump())
|
||||
else:
|
||||
return None
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
@@ -72,14 +97,46 @@ async def unset_agent_featured(
|
||||
user: autogpt_libs.auth.User = fastapi.Depends(
|
||||
autogpt_libs.auth.requires_admin_user
|
||||
),
|
||||
):
|
||||
) -> market.model.FeaturedAgentResponse | None:
|
||||
"""
|
||||
A basic endpoint to unset an agent as featured in the database.
|
||||
"""
|
||||
try:
|
||||
featured = await market.db.remove_featured_category(agent_id, category=category)
|
||||
if featured:
|
||||
return market.model.FeaturedAgentResponse(**featured.model_dump())
|
||||
else:
|
||||
return None
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
await market.db.remove_featured_category(agent_id, category=category)
|
||||
return fastapi.responses.Response(status_code=200)
|
||||
|
||||
@router.get("/agent/not-featured")
|
||||
async def get_not_featured_agents(
|
||||
page: int = fastapi.Query(1, ge=1, description="Page number"),
|
||||
page_size: int = fastapi.Query(
|
||||
10, ge=1, le=100, description="Number of items per page"
|
||||
),
|
||||
user: autogpt_libs.auth.User = fastapi.Depends(
|
||||
autogpt_libs.auth.requires_admin_user
|
||||
),
|
||||
) -> market.model.AgentListResponse:
|
||||
"""
|
||||
A basic endpoint to get all not featured agents in the database.
|
||||
"""
|
||||
try:
|
||||
agents = await market.db.get_not_featured_agents(page=page, page_size=page_size)
|
||||
return market.model.AgentListResponse(
|
||||
agents=[
|
||||
market.model.AgentResponse(**agent.model_dump()) for agent in agents
|
||||
],
|
||||
total_count=len(agents),
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=999,
|
||||
)
|
||||
except market.db.AgentQueryError as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
@@ -177,3 +234,15 @@ async def review_submission(
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/categories")
|
||||
async def get_categories() -> market.model.CategoriesResponse:
|
||||
"""
|
||||
A basic endpoint to get all available categories.
|
||||
"""
|
||||
try:
|
||||
categories = await market.db.get_all_categories()
|
||||
return categories
|
||||
except Exception as e:
|
||||
raise fastapi.HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
Reference in New Issue
Block a user