Files
sim/stores/workflow/store.ts
2025-02-15 12:13:27 -08:00

457 lines
13 KiB
TypeScript

import { Edge } from 'reactflow'
import { create } from 'zustand'
import { devtools } from 'zustand/middleware'
import { getBlock } from '@/blocks'
import { resolveOutputType } from '@/blocks/utils'
import { WorkflowStoreWithHistory, pushHistory, withHistory } from './middleware'
import { Loop, Position, SubBlockState } from './types'
import { detectCycle } from './utils'
const initialState = {
blocks: {},
edges: [],
loops: {},
lastSaved: undefined,
history: {
past: [],
present: {
state: { blocks: {}, edges: [], loops: {} },
timestamp: Date.now(),
action: 'Initial state',
},
future: [],
},
}
export const useWorkflowStore = create<WorkflowStoreWithHistory>()(
devtools(
withHistory((set, get) => ({
...initialState,
undo: () => {},
redo: () => {},
canUndo: () => false,
canRedo: () => false,
revertToHistoryState: () => {},
updateSubBlock: (blockId: string, subBlockId: string, value: any) => {
set((state) => {
const block = state.blocks[blockId]
if (!block) return state
// Handle different value types appropriately
const processedValue = Array.isArray(value)
? value
: block.subBlocks[subBlockId]?.type === 'slider'
? Number(value) // Convert slider values to numbers
: typeof value === 'string'
? value
: JSON.stringify(value, null, 2)
// Only attempt JSON parsing for agent responseFormat validation
if (
block.type === 'agent' &&
subBlockId === 'responseFormat' &&
typeof processedValue === 'string'
) {
try {
// Parse the input string to validate JSON but keep original string value
const parsed = JSON.parse(processedValue)
// Simple validation of required schema structure
if (!parsed.fields || !Array.isArray(parsed.fields)) {
console.error('Validation failed: missing fields array')
throw new Error('Response format must have a fields array')
}
for (const field of parsed.fields) {
if (!field.name || !field.type) {
console.error('Validation failed: field missing name or type', field)
throw new Error('Each field must have a name and type')
}
if (!['string', 'number', 'boolean', 'array', 'object'].includes(field.type)) {
console.error('Validation failed: invalid field type', field)
throw new Error(
`Invalid type "${field.type}" - must be one of: string, number, boolean, array, object`
)
}
}
// Don't modify the value, keep it as the original string
} catch (error: any) {
console.error('responseFormat validation error:', error)
throw new Error(`Invalid JSON schema: ${error.message}`)
}
}
return {
blocks: {
...state.blocks,
[blockId]: {
...block,
subBlocks: {
...block.subBlocks,
[subBlockId]: {
...block.subBlocks[subBlockId],
value: processedValue,
},
},
},
},
}
})
get().updateLastSaved()
},
addBlock: (id: string, type: string, name: string, position: Position) => {
const blockConfig = getBlock(type)
if (!blockConfig) return
const subBlocks: Record<string, SubBlockState> = {}
blockConfig.subBlocks.forEach((subBlock) => {
const subBlockId = subBlock.id
subBlocks[subBlockId] = {
id: subBlockId,
type: subBlock.type,
value: null,
}
})
const outputs = resolveOutputType(blockConfig.outputs, subBlocks)
const newState = {
blocks: {
...get().blocks,
[id]: {
id,
type,
name,
position,
subBlocks,
outputs,
enabled: true,
horizontalHandles: true,
isWide: false,
height: 0,
},
},
edges: [...get().edges],
loops: { ...get().loops },
}
set(newState)
pushHistory(set, get, newState, `Add ${type} block`)
get().updateLastSaved()
},
updateBlockPosition: (id: string, position: Position) => {
set((state) => ({
blocks: {
...state.blocks,
[id]: {
...state.blocks[id],
position,
},
},
edges: [...state.edges],
}))
get().updateLastSaved()
},
removeBlock: (id: string) => {
const newState = {
blocks: { ...get().blocks },
edges: [...get().edges].filter((edge) => edge.source !== id && edge.target !== id),
loops: { ...get().loops },
}
// Remove the block from any loops that contain it
Object.entries(newState.loops).forEach(([loopId, loop]) => {
if (loop.nodes.includes(id)) {
// If the loop would only have 1 or 0 nodes after removal, delete the loop
if (loop.nodes.length <= 2) {
delete newState.loops[loopId]
} else {
// Otherwise, just remove the node from the loop
newState.loops[loopId] = {
...loop,
nodes: loop.nodes.filter((nodeId) => nodeId !== id),
}
}
}
})
// Delete the block itself
delete newState.blocks[id]
set(newState)
pushHistory(set, get, newState, 'Remove block')
get().updateLastSaved()
},
addEdge: (edge: Edge) => {
// Check for duplicate connections
const isDuplicate = get().edges.some(
(existingEdge) =>
existingEdge.source === edge.source &&
existingEdge.target === edge.target &&
existingEdge.sourceHandle === edge.sourceHandle &&
existingEdge.targetHandle === edge.targetHandle
)
// If it's a duplicate connection, return early without adding the edge
if (isDuplicate) {
return
}
const newEdge = {
id: edge.id || crypto.randomUUID(),
source: edge.source,
target: edge.target,
sourceHandle: edge.sourceHandle,
targetHandle: edge.targetHandle,
}
const newEdges = [...get().edges, newEdge]
// Recalculate all loops after adding the edge
const newLoops: Record<string, Loop> = {}
const processedPaths = new Set<string>()
// Check for cycles from each node
const nodes = new Set(newEdges.map((e) => e.source))
nodes.forEach((node) => {
const { paths } = detectCycle(newEdges, node)
paths.forEach((path) => {
// Create a canonical path representation for deduplication
const canonicalPath = [...path].sort().join(',')
if (!processedPaths.has(canonicalPath)) {
const loopId = crypto.randomUUID()
newLoops[loopId] = {
id: loopId,
nodes: path,
maxIterations: 5,
}
processedPaths.add(canonicalPath)
}
})
})
const newState = {
blocks: { ...get().blocks },
edges: newEdges,
loops: newLoops,
}
set(newState)
pushHistory(set, get, newState, 'Add connection')
get().updateLastSaved()
},
removeEdge: (edgeId: string) => {
const newEdges = get().edges.filter((edge) => edge.id !== edgeId)
// Recalculate all loops after edge removal
const newLoops: Record<string, Loop> = {}
const processedPaths = new Set<string>()
// Check for cycles from each node
const nodes = new Set(newEdges.map((e) => e.source))
nodes.forEach((node) => {
const { paths } = detectCycle(newEdges, node)
paths.forEach((path) => {
// Create a canonical path representation for deduplication
const canonicalPath = [...path].sort().join(',')
if (!processedPaths.has(canonicalPath)) {
const loopId = crypto.randomUUID()
newLoops[loopId] = {
id: loopId,
nodes: path,
maxIterations: 5,
}
processedPaths.add(canonicalPath)
}
})
})
const newState = {
blocks: { ...get().blocks },
edges: newEdges,
loops: newLoops,
}
set(newState)
pushHistory(set, get, newState, 'Remove connection')
get().updateLastSaved()
},
clear: () => {
const newState = {
blocks: {},
edges: [],
loops: {},
history: {
past: [],
present: {
state: { blocks: {}, edges: [], loops: {} },
timestamp: Date.now(),
action: 'Initial state',
},
future: [],
},
lastSaved: Date.now(),
}
set(newState)
return newState
},
updateLastSaved: () => {
set({ lastSaved: Date.now() })
},
toggleBlockEnabled: (id: string) => {
const newState = {
blocks: {
...get().blocks,
[id]: {
...get().blocks[id],
enabled: !get().blocks[id].enabled,
},
},
edges: [...get().edges],
}
set(newState)
get().updateLastSaved()
},
duplicateBlock: (id: string) => {
const block = get().blocks[id]
if (!block) return
const newId = crypto.randomUUID()
const offsetPosition = {
x: block.position.x + 250,
y: block.position.y + 20,
}
// More efficient name handling
const match = block.name.match(/(.*?)(\d+)?$/)
const newName =
match && match[2] ? `${match[1]}${parseInt(match[2]) + 1}` : `${block.name} 1`
const newSubBlocks = Object.entries(block.subBlocks).reduce(
(acc, [subId, subBlock]) => ({
...acc,
[subId]: {
...subBlock,
value: JSON.parse(JSON.stringify(subBlock.value)),
},
}),
{}
)
const newState = {
blocks: {
...get().blocks,
[newId]: {
...block,
id: newId,
name: newName,
position: offsetPosition,
subBlocks: newSubBlocks,
},
},
edges: [...get().edges],
loops: { ...get().loops },
}
set(newState)
pushHistory(set, get, newState, `Duplicate ${block.type} block`)
get().updateLastSaved()
},
toggleBlockHandles: (id: string) => {
const newState = {
blocks: {
...get().blocks,
[id]: {
...get().blocks[id],
horizontalHandles: !get().blocks[id].horizontalHandles,
},
},
edges: [...get().edges],
}
set(newState)
get().updateLastSaved()
},
updateBlockName: (id: string, name: string) => {
const newState = {
blocks: {
...get().blocks,
[id]: {
...get().blocks[id],
name,
},
},
edges: [...get().edges],
loops: { ...get().loops },
}
set(newState)
pushHistory(set, get, newState, `${name} block name updated`)
get().updateLastSaved()
},
toggleBlockWide: (id: string) => {
set((state) => ({
blocks: {
...state.blocks,
[id]: {
...state.blocks[id],
isWide: !state.blocks[id].isWide,
},
},
edges: [...state.edges],
loops: { ...get().loops },
}))
get().updateLastSaved()
},
updateBlockHeight: (id: string, height: number) => {
set((state) => ({
blocks: {
...state.blocks,
[id]: {
...state.blocks[id],
height,
},
},
edges: [...state.edges],
}))
get().updateLastSaved()
},
updateLoopMaxIterations: (loopId: string, maxIterations: number) => {
const newState = {
blocks: { ...get().blocks },
edges: [...get().edges],
loops: {
...get().loops,
[loopId]: {
...get().loops[loopId],
maxIterations: Math.max(1, Math.min(50, maxIterations)), // Clamp between 1-50
},
},
}
set(newState)
pushHistory(set, get, newState, 'Update loop max iterations')
get().updateLastSaved()
},
})),
{ name: 'workflow-store' }
)
)