mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-31 09:48:06 -05:00
Compare commits
4 Commits
main
...
fix/condit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
48baa41459 | ||
|
|
d0f1f32e8d | ||
|
|
2d799b3272 | ||
|
|
92403e0594 |
@@ -150,7 +150,9 @@ export function Editor() {
|
|||||||
blockSubBlockValues,
|
blockSubBlockValues,
|
||||||
canonicalIndex
|
canonicalIndex
|
||||||
)
|
)
|
||||||
const displayAdvancedOptions = advancedMode || advancedValuesPresent
|
const displayAdvancedOptions = userPermissions.canEdit
|
||||||
|
? advancedMode
|
||||||
|
: advancedMode || advancedValuesPresent
|
||||||
|
|
||||||
const hasAdvancedOnlyFields = useMemo(() => {
|
const hasAdvancedOnlyFields = useMemo(() => {
|
||||||
for (const subBlock of subBlocksForCanonical) {
|
for (const subBlock of subBlocksForCanonical) {
|
||||||
|
|||||||
@@ -322,7 +322,8 @@ describe('ConditionBlockHandler', () => {
|
|||||||
|
|
||||||
await handler.execute(mockContext, mockBlock, inputs)
|
await handler.execute(mockContext, mockBlock, inputs)
|
||||||
|
|
||||||
expect(mockCollectBlockData).toHaveBeenCalledWith(mockContext)
|
// collectBlockData is now called with the current node ID for parallel branch context
|
||||||
|
expect(mockCollectBlockData).toHaveBeenCalledWith(mockContext, mockBlock.id)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle function_execute tool failure', async () => {
|
it('should handle function_execute tool failure', async () => {
|
||||||
@@ -620,4 +621,248 @@ describe('ConditionBlockHandler', () => {
|
|||||||
expect(mockContext.decisions.condition.has(mockBlock.id)).toBe(false)
|
expect(mockContext.decisions.condition.has(mockBlock.id)).toBe(false)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('Parallel branch handling', () => {
|
||||||
|
it('should resolve connections and block data correctly when inside a parallel branch', async () => {
|
||||||
|
// Simulate a condition block inside a parallel branch
|
||||||
|
// Virtual block ID uses subscript notation: blockId₍branchIndex₎
|
||||||
|
const parallelConditionBlock: SerializedBlock = {
|
||||||
|
id: 'cond-block-1₍0₎', // Virtual ID for branch 0
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Source block also has a virtual ID in the same branch
|
||||||
|
const sourceBlockVirtualId = 'agent-block-1₍0₎'
|
||||||
|
|
||||||
|
// Set up workflow with connections using BASE block IDs (as they are in the workflow definition)
|
||||||
|
const parallelWorkflow: SerializedWorkflow = {
|
||||||
|
blocks: [
|
||||||
|
{
|
||||||
|
id: 'agent-block-1',
|
||||||
|
metadata: { id: 'agent', name: 'Agent' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'cond-block-1',
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 100, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'target-block-1',
|
||||||
|
metadata: { id: 'api', name: 'Target' },
|
||||||
|
position: { x: 200, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
connections: [
|
||||||
|
// Connections use base IDs, not virtual IDs
|
||||||
|
{ source: 'agent-block-1', target: 'cond-block-1' },
|
||||||
|
{ source: 'cond-block-1', target: 'target-block-1', sourceHandle: 'condition-cond1' },
|
||||||
|
],
|
||||||
|
loops: [],
|
||||||
|
parallels: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Block states use virtual IDs (as outputs are stored per-branch)
|
||||||
|
const parallelBlockStates = new Map<string, BlockState>([
|
||||||
|
[
|
||||||
|
sourceBlockVirtualId,
|
||||||
|
{ output: { response: 'hello from branch 0', success: true }, executed: true },
|
||||||
|
],
|
||||||
|
])
|
||||||
|
|
||||||
|
const parallelContext: ExecutionContext = {
|
||||||
|
workflowId: 'test-workflow-id',
|
||||||
|
workspaceId: 'test-workspace-id',
|
||||||
|
workflow: parallelWorkflow,
|
||||||
|
blockStates: parallelBlockStates,
|
||||||
|
blockLogs: [],
|
||||||
|
completedBlocks: new Set(),
|
||||||
|
decisions: {
|
||||||
|
router: new Map(),
|
||||||
|
condition: new Map(),
|
||||||
|
},
|
||||||
|
environmentVariables: {},
|
||||||
|
workflowVariables: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
const conditions = [
|
||||||
|
{ id: 'cond1', title: 'if', value: 'context.response === "hello from branch 0"' },
|
||||||
|
{ id: 'else1', title: 'else', value: '' },
|
||||||
|
]
|
||||||
|
const inputs = { conditions: JSON.stringify(conditions) }
|
||||||
|
|
||||||
|
const result = await handler.execute(parallelContext, parallelConditionBlock, inputs)
|
||||||
|
|
||||||
|
// The condition should evaluate to true because:
|
||||||
|
// 1. Connection lookup uses base ID 'cond-block-1' (extracted from 'cond-block-1₍0₎')
|
||||||
|
// 2. Source block output is found at virtual ID 'agent-block-1₍0₎' (same branch)
|
||||||
|
// 3. The evaluation context contains { response: 'hello from branch 0' }
|
||||||
|
expect((result as any).conditionResult).toBe(true)
|
||||||
|
expect((result as any).selectedOption).toBe('cond1')
|
||||||
|
expect((result as any).selectedPath).toEqual({
|
||||||
|
blockId: 'target-block-1',
|
||||||
|
blockType: 'api',
|
||||||
|
blockTitle: 'Target',
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should find correct source block output in parallel branch context', async () => {
|
||||||
|
// Test that when multiple branches exist, the correct branch output is used
|
||||||
|
const parallelConditionBlock: SerializedBlock = {
|
||||||
|
id: 'cond-block-1₍1₎', // Virtual ID for branch 1
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
const parallelWorkflow: SerializedWorkflow = {
|
||||||
|
blocks: [
|
||||||
|
{
|
||||||
|
id: 'agent-block-1',
|
||||||
|
metadata: { id: 'agent', name: 'Agent' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'cond-block-1',
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 100, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'target-block-1',
|
||||||
|
metadata: { id: 'api', name: 'Target' },
|
||||||
|
position: { x: 200, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
connections: [
|
||||||
|
{ source: 'agent-block-1', target: 'cond-block-1' },
|
||||||
|
{ source: 'cond-block-1', target: 'target-block-1', sourceHandle: 'condition-cond1' },
|
||||||
|
],
|
||||||
|
loops: [],
|
||||||
|
parallels: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multiple branches have executed - each has different output
|
||||||
|
const parallelBlockStates = new Map<string, BlockState>([
|
||||||
|
['agent-block-1₍0₎', { output: { value: 10 }, executed: true }],
|
||||||
|
['agent-block-1₍1₎', { output: { value: 25 }, executed: true }], // Branch 1 has value 25
|
||||||
|
['agent-block-1₍2₎', { output: { value: 5 }, executed: true }],
|
||||||
|
])
|
||||||
|
|
||||||
|
const parallelContext: ExecutionContext = {
|
||||||
|
workflowId: 'test-workflow-id',
|
||||||
|
workspaceId: 'test-workspace-id',
|
||||||
|
workflow: parallelWorkflow,
|
||||||
|
blockStates: parallelBlockStates,
|
||||||
|
blockLogs: [],
|
||||||
|
completedBlocks: new Set(),
|
||||||
|
decisions: {
|
||||||
|
router: new Map(),
|
||||||
|
condition: new Map(),
|
||||||
|
},
|
||||||
|
environmentVariables: {},
|
||||||
|
workflowVariables: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Condition checks if value > 20 - should be true for branch 1 (value=25)
|
||||||
|
const conditions = [
|
||||||
|
{ id: 'cond1', title: 'if', value: 'context.value > 20' },
|
||||||
|
{ id: 'else1', title: 'else', value: '' },
|
||||||
|
]
|
||||||
|
const inputs = { conditions: JSON.stringify(conditions) }
|
||||||
|
|
||||||
|
const result = await handler.execute(parallelContext, parallelConditionBlock, inputs)
|
||||||
|
|
||||||
|
// Should evaluate using branch 1's data (value=25), not branch 0 (value=10) or branch 2 (value=5)
|
||||||
|
expect((result as any).conditionResult).toBe(true)
|
||||||
|
expect((result as any).selectedOption).toBe('cond1')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should fall back to else when condition is false in parallel branch', async () => {
|
||||||
|
const parallelConditionBlock: SerializedBlock = {
|
||||||
|
id: 'cond-block-1₍2₎', // Virtual ID for branch 2
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
const parallelWorkflow: SerializedWorkflow = {
|
||||||
|
blocks: [
|
||||||
|
{
|
||||||
|
id: 'agent-block-1',
|
||||||
|
metadata: { id: 'agent', name: 'Agent' },
|
||||||
|
position: { x: 0, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'cond-block-1',
|
||||||
|
metadata: { id: 'condition', name: 'Condition' },
|
||||||
|
position: { x: 100, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'target-true',
|
||||||
|
metadata: { id: 'api', name: 'True Path' },
|
||||||
|
position: { x: 200, y: 0 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'target-false',
|
||||||
|
metadata: { id: 'api', name: 'False Path' },
|
||||||
|
position: { x: 200, y: 100 },
|
||||||
|
config: {},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
connections: [
|
||||||
|
{ source: 'agent-block-1', target: 'cond-block-1' },
|
||||||
|
{ source: 'cond-block-1', target: 'target-true', sourceHandle: 'condition-cond1' },
|
||||||
|
{ source: 'cond-block-1', target: 'target-false', sourceHandle: 'condition-else1' },
|
||||||
|
],
|
||||||
|
loops: [],
|
||||||
|
parallels: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
const parallelBlockStates = new Map<string, BlockState>([
|
||||||
|
['agent-block-1₍0₎', { output: { value: 100 }, executed: true }],
|
||||||
|
['agent-block-1₍1₎', { output: { value: 50 }, executed: true }],
|
||||||
|
['agent-block-1₍2₎', { output: { value: 5 }, executed: true }], // Branch 2 has value 5
|
||||||
|
])
|
||||||
|
|
||||||
|
const parallelContext: ExecutionContext = {
|
||||||
|
workflowId: 'test-workflow-id',
|
||||||
|
workspaceId: 'test-workspace-id',
|
||||||
|
workflow: parallelWorkflow,
|
||||||
|
blockStates: parallelBlockStates,
|
||||||
|
blockLogs: [],
|
||||||
|
completedBlocks: new Set(),
|
||||||
|
decisions: {
|
||||||
|
router: new Map(),
|
||||||
|
condition: new Map(),
|
||||||
|
},
|
||||||
|
environmentVariables: {},
|
||||||
|
workflowVariables: {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Condition checks if value > 20 - should be false for branch 2 (value=5)
|
||||||
|
const conditions = [
|
||||||
|
{ id: 'cond1', title: 'if', value: 'context.value > 20' },
|
||||||
|
{ id: 'else1', title: 'else', value: '' },
|
||||||
|
]
|
||||||
|
const inputs = { conditions: JSON.stringify(conditions) }
|
||||||
|
|
||||||
|
const result = await handler.execute(parallelContext, parallelConditionBlock, inputs)
|
||||||
|
|
||||||
|
// Should fall back to else path because branch 2's value (5) is not > 20
|
||||||
|
expect((result as any).conditionResult).toBe(true)
|
||||||
|
expect((result as any).selectedOption).toBe('else1')
|
||||||
|
expect((result as any).selectedPath.blockId).toBe('target-false')
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,6 +3,12 @@ import type { BlockOutput } from '@/blocks/types'
|
|||||||
import { BlockType, CONDITION, DEFAULTS, EDGE } from '@/executor/constants'
|
import { BlockType, CONDITION, DEFAULTS, EDGE } from '@/executor/constants'
|
||||||
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
import type { BlockHandler, ExecutionContext } from '@/executor/types'
|
||||||
import { collectBlockData } from '@/executor/utils/block-data'
|
import { collectBlockData } from '@/executor/utils/block-data'
|
||||||
|
import {
|
||||||
|
buildBranchNodeId,
|
||||||
|
extractBaseBlockId,
|
||||||
|
extractBranchIndex,
|
||||||
|
isBranchNodeId,
|
||||||
|
} from '@/executor/utils/subflow-utils'
|
||||||
import type { SerializedBlock } from '@/serializer/types'
|
import type { SerializedBlock } from '@/serializer/types'
|
||||||
import { executeTool } from '@/tools'
|
import { executeTool } from '@/tools'
|
||||||
|
|
||||||
@@ -18,7 +24,8 @@ const CONDITION_TIMEOUT_MS = 5000
|
|||||||
export async function evaluateConditionExpression(
|
export async function evaluateConditionExpression(
|
||||||
ctx: ExecutionContext,
|
ctx: ExecutionContext,
|
||||||
conditionExpression: string,
|
conditionExpression: string,
|
||||||
providedEvalContext?: Record<string, any>
|
providedEvalContext?: Record<string, any>,
|
||||||
|
currentNodeId?: string
|
||||||
): Promise<boolean> {
|
): Promise<boolean> {
|
||||||
const evalContext = providedEvalContext || {}
|
const evalContext = providedEvalContext || {}
|
||||||
|
|
||||||
@@ -26,7 +33,7 @@ export async function evaluateConditionExpression(
|
|||||||
const contextSetup = `const context = ${JSON.stringify(evalContext)};`
|
const contextSetup = `const context = ${JSON.stringify(evalContext)};`
|
||||||
const code = `${contextSetup}\nreturn Boolean(${conditionExpression})`
|
const code = `${contextSetup}\nreturn Boolean(${conditionExpression})`
|
||||||
|
|
||||||
const { blockData, blockNameMapping, blockOutputSchemas } = collectBlockData(ctx)
|
const { blockData, blockNameMapping, blockOutputSchemas } = collectBlockData(ctx, currentNodeId)
|
||||||
|
|
||||||
const result = await executeTool(
|
const result = await executeTool(
|
||||||
'function_execute',
|
'function_execute',
|
||||||
@@ -83,7 +90,19 @@ export class ConditionBlockHandler implements BlockHandler {
|
|||||||
): Promise<BlockOutput> {
|
): Promise<BlockOutput> {
|
||||||
const conditions = this.parseConditions(inputs.conditions)
|
const conditions = this.parseConditions(inputs.conditions)
|
||||||
|
|
||||||
const sourceBlockId = ctx.workflow?.connections.find((conn) => conn.target === block.id)?.source
|
const baseBlockId = extractBaseBlockId(block.id)
|
||||||
|
const branchIndex = isBranchNodeId(block.id) ? extractBranchIndex(block.id) : null
|
||||||
|
|
||||||
|
const sourceConnection = ctx.workflow?.connections.find((conn) => conn.target === baseBlockId)
|
||||||
|
let sourceBlockId = sourceConnection?.source
|
||||||
|
|
||||||
|
if (sourceBlockId && branchIndex !== null) {
|
||||||
|
const virtualSourceId = buildBranchNodeId(sourceBlockId, branchIndex)
|
||||||
|
if (ctx.blockStates.has(virtualSourceId)) {
|
||||||
|
sourceBlockId = virtualSourceId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const evalContext = this.buildEvaluationContext(ctx, sourceBlockId)
|
const evalContext = this.buildEvaluationContext(ctx, sourceBlockId)
|
||||||
const rawSourceOutput = sourceBlockId ? ctx.blockStates.get(sourceBlockId)?.output : null
|
const rawSourceOutput = sourceBlockId ? ctx.blockStates.get(sourceBlockId)?.output : null
|
||||||
|
|
||||||
@@ -91,13 +110,16 @@ export class ConditionBlockHandler implements BlockHandler {
|
|||||||
// thinking this block is pausing (it was already resumed by the HITL block)
|
// thinking this block is pausing (it was already resumed by the HITL block)
|
||||||
const sourceOutput = this.filterPauseMetadata(rawSourceOutput)
|
const sourceOutput = this.filterPauseMetadata(rawSourceOutput)
|
||||||
|
|
||||||
const outgoingConnections = ctx.workflow?.connections.filter((conn) => conn.source === block.id)
|
const outgoingConnections = ctx.workflow?.connections.filter(
|
||||||
|
(conn) => conn.source === baseBlockId
|
||||||
|
)
|
||||||
|
|
||||||
const { selectedConnection, selectedCondition } = await this.evaluateConditions(
|
const { selectedConnection, selectedCondition } = await this.evaluateConditions(
|
||||||
conditions,
|
conditions,
|
||||||
outgoingConnections || [],
|
outgoingConnections || [],
|
||||||
evalContext,
|
evalContext,
|
||||||
ctx
|
ctx,
|
||||||
|
block.id
|
||||||
)
|
)
|
||||||
|
|
||||||
if (!selectedConnection || !selectedCondition) {
|
if (!selectedConnection || !selectedCondition) {
|
||||||
@@ -170,7 +192,8 @@ export class ConditionBlockHandler implements BlockHandler {
|
|||||||
conditions: Array<{ id: string; title: string; value: string }>,
|
conditions: Array<{ id: string; title: string; value: string }>,
|
||||||
outgoingConnections: Array<{ source: string; target: string; sourceHandle?: string }>,
|
outgoingConnections: Array<{ source: string; target: string; sourceHandle?: string }>,
|
||||||
evalContext: Record<string, any>,
|
evalContext: Record<string, any>,
|
||||||
ctx: ExecutionContext
|
ctx: ExecutionContext,
|
||||||
|
currentNodeId?: string
|
||||||
): Promise<{
|
): Promise<{
|
||||||
selectedConnection: { target: string; sourceHandle?: string } | null
|
selectedConnection: { target: string; sourceHandle?: string } | null
|
||||||
selectedCondition: { id: string; title: string; value: string } | null
|
selectedCondition: { id: string; title: string; value: string } | null
|
||||||
@@ -189,7 +212,8 @@ export class ConditionBlockHandler implements BlockHandler {
|
|||||||
const conditionMet = await evaluateConditionExpression(
|
const conditionMet = await evaluateConditionExpression(
|
||||||
ctx,
|
ctx,
|
||||||
conditionValueString,
|
conditionValueString,
|
||||||
evalContext
|
evalContext,
|
||||||
|
currentNodeId
|
||||||
)
|
)
|
||||||
|
|
||||||
if (conditionMet) {
|
if (conditionMet) {
|
||||||
|
|||||||
@@ -2,6 +2,11 @@ import { normalizeInputFormatValue } from '@/lib/workflows/input-format'
|
|||||||
import { isTriggerBehavior, normalizeName } from '@/executor/constants'
|
import { isTriggerBehavior, normalizeName } from '@/executor/constants'
|
||||||
import type { ExecutionContext } from '@/executor/types'
|
import type { ExecutionContext } from '@/executor/types'
|
||||||
import type { OutputSchema } from '@/executor/utils/block-reference'
|
import type { OutputSchema } from '@/executor/utils/block-reference'
|
||||||
|
import {
|
||||||
|
extractBaseBlockId,
|
||||||
|
extractBranchIndex,
|
||||||
|
isBranchNodeId,
|
||||||
|
} from '@/executor/utils/subflow-utils'
|
||||||
import type { SerializedBlock } from '@/serializer/types'
|
import type { SerializedBlock } from '@/serializer/types'
|
||||||
import type { ToolConfig } from '@/tools/types'
|
import type { ToolConfig } from '@/tools/types'
|
||||||
import { getTool } from '@/tools/utils'
|
import { getTool } from '@/tools/utils'
|
||||||
@@ -86,14 +91,30 @@ export function getBlockSchema(
|
|||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
export function collectBlockData(ctx: ExecutionContext): BlockDataCollection {
|
export function collectBlockData(
|
||||||
|
ctx: ExecutionContext,
|
||||||
|
currentNodeId?: string
|
||||||
|
): BlockDataCollection {
|
||||||
const blockData: Record<string, unknown> = {}
|
const blockData: Record<string, unknown> = {}
|
||||||
const blockNameMapping: Record<string, string> = {}
|
const blockNameMapping: Record<string, string> = {}
|
||||||
const blockOutputSchemas: Record<string, OutputSchema> = {}
|
const blockOutputSchemas: Record<string, OutputSchema> = {}
|
||||||
|
|
||||||
|
const branchIndex =
|
||||||
|
currentNodeId && isBranchNodeId(currentNodeId) ? extractBranchIndex(currentNodeId) : null
|
||||||
|
|
||||||
for (const [id, state] of ctx.blockStates.entries()) {
|
for (const [id, state] of ctx.blockStates.entries()) {
|
||||||
if (state.output !== undefined) {
|
if (state.output !== undefined) {
|
||||||
blockData[id] = state.output
|
blockData[id] = state.output
|
||||||
|
|
||||||
|
if (branchIndex !== null && isBranchNodeId(id)) {
|
||||||
|
const stateBranchIndex = extractBranchIndex(id)
|
||||||
|
if (stateBranchIndex === branchIndex) {
|
||||||
|
const baseId = extractBaseBlockId(id)
|
||||||
|
if (blockData[baseId] === undefined) {
|
||||||
|
blockData[baseId] = state.output
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ import {
|
|||||||
ensureOrganizationForTeamSubscription,
|
ensureOrganizationForTeamSubscription,
|
||||||
syncSubscriptionUsageLimits,
|
syncSubscriptionUsageLimits,
|
||||||
} from '@/lib/billing/organization'
|
} from '@/lib/billing/organization'
|
||||||
import { getPlans } from '@/lib/billing/plans'
|
import { getPlans, resolvePlanFromStripeSubscription } from '@/lib/billing/plans'
|
||||||
import { syncSeatsFromStripeQuantity } from '@/lib/billing/validation/seat-management'
|
import { syncSeatsFromStripeQuantity } from '@/lib/billing/validation/seat-management'
|
||||||
import { handleChargeDispute, handleDisputeClosed } from '@/lib/billing/webhooks/disputes'
|
import { handleChargeDispute, handleDisputeClosed } from '@/lib/billing/webhooks/disputes'
|
||||||
import { handleManualEnterpriseSubscription } from '@/lib/billing/webhooks/enterprise'
|
import { handleManualEnterpriseSubscription } from '@/lib/billing/webhooks/enterprise'
|
||||||
@@ -2641,29 +2641,42 @@ export const auth = betterAuth({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
onSubscriptionComplete: async ({
|
onSubscriptionComplete: async ({
|
||||||
|
stripeSubscription,
|
||||||
subscription,
|
subscription,
|
||||||
}: {
|
}: {
|
||||||
event: Stripe.Event
|
event: Stripe.Event
|
||||||
stripeSubscription: Stripe.Subscription
|
stripeSubscription: Stripe.Subscription
|
||||||
subscription: any
|
subscription: any
|
||||||
}) => {
|
}) => {
|
||||||
|
const { priceId, planFromStripe, isTeamPlan } =
|
||||||
|
resolvePlanFromStripeSubscription(stripeSubscription)
|
||||||
|
|
||||||
logger.info('[onSubscriptionComplete] Subscription created', {
|
logger.info('[onSubscriptionComplete] Subscription created', {
|
||||||
subscriptionId: subscription.id,
|
subscriptionId: subscription.id,
|
||||||
referenceId: subscription.referenceId,
|
referenceId: subscription.referenceId,
|
||||||
plan: subscription.plan,
|
dbPlan: subscription.plan,
|
||||||
|
planFromStripe,
|
||||||
|
priceId,
|
||||||
status: subscription.status,
|
status: subscription.status,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const subscriptionForOrgCreation = isTeamPlan
|
||||||
|
? { ...subscription, plan: 'team' }
|
||||||
|
: subscription
|
||||||
|
|
||||||
let resolvedSubscription = subscription
|
let resolvedSubscription = subscription
|
||||||
try {
|
try {
|
||||||
resolvedSubscription = await ensureOrganizationForTeamSubscription(subscription)
|
resolvedSubscription = await ensureOrganizationForTeamSubscription(
|
||||||
|
subscriptionForOrgCreation
|
||||||
|
)
|
||||||
} catch (orgError) {
|
} catch (orgError) {
|
||||||
logger.error(
|
logger.error(
|
||||||
'[onSubscriptionComplete] Failed to ensure organization for team subscription',
|
'[onSubscriptionComplete] Failed to ensure organization for team subscription',
|
||||||
{
|
{
|
||||||
subscriptionId: subscription.id,
|
subscriptionId: subscription.id,
|
||||||
referenceId: subscription.referenceId,
|
referenceId: subscription.referenceId,
|
||||||
plan: subscription.plan,
|
dbPlan: subscription.plan,
|
||||||
|
planFromStripe,
|
||||||
error: orgError instanceof Error ? orgError.message : String(orgError),
|
error: orgError instanceof Error ? orgError.message : String(orgError),
|
||||||
stack: orgError instanceof Error ? orgError.stack : undefined,
|
stack: orgError instanceof Error ? orgError.stack : undefined,
|
||||||
}
|
}
|
||||||
@@ -2684,22 +2697,67 @@ export const auth = betterAuth({
|
|||||||
event: Stripe.Event
|
event: Stripe.Event
|
||||||
subscription: any
|
subscription: any
|
||||||
}) => {
|
}) => {
|
||||||
|
const stripeSubscription = event.data.object as Stripe.Subscription
|
||||||
|
const { priceId, planFromStripe, isTeamPlan } =
|
||||||
|
resolvePlanFromStripeSubscription(stripeSubscription)
|
||||||
|
|
||||||
|
if (priceId && !planFromStripe) {
|
||||||
|
logger.warn(
|
||||||
|
'[onSubscriptionUpdate] Could not determine plan from Stripe price ID',
|
||||||
|
{
|
||||||
|
subscriptionId: subscription.id,
|
||||||
|
priceId,
|
||||||
|
dbPlan: subscription.plan,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const isUpgradeToTeam =
|
||||||
|
isTeamPlan &&
|
||||||
|
subscription.plan !== 'team' &&
|
||||||
|
!subscription.referenceId.startsWith('org_')
|
||||||
|
|
||||||
|
const effectivePlanForTeamFeatures = planFromStripe ?? subscription.plan
|
||||||
|
|
||||||
logger.info('[onSubscriptionUpdate] Subscription updated', {
|
logger.info('[onSubscriptionUpdate] Subscription updated', {
|
||||||
subscriptionId: subscription.id,
|
subscriptionId: subscription.id,
|
||||||
status: subscription.status,
|
status: subscription.status,
|
||||||
plan: subscription.plan,
|
dbPlan: subscription.plan,
|
||||||
|
planFromStripe,
|
||||||
|
isUpgradeToTeam,
|
||||||
|
referenceId: subscription.referenceId,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
const subscriptionForOrgCreation = isUpgradeToTeam
|
||||||
|
? { ...subscription, plan: 'team' }
|
||||||
|
: subscription
|
||||||
|
|
||||||
let resolvedSubscription = subscription
|
let resolvedSubscription = subscription
|
||||||
try {
|
try {
|
||||||
resolvedSubscription = await ensureOrganizationForTeamSubscription(subscription)
|
resolvedSubscription = await ensureOrganizationForTeamSubscription(
|
||||||
|
subscriptionForOrgCreation
|
||||||
|
)
|
||||||
|
|
||||||
|
if (isUpgradeToTeam) {
|
||||||
|
logger.info(
|
||||||
|
'[onSubscriptionUpdate] Detected Pro -> Team upgrade, ensured organization creation',
|
||||||
|
{
|
||||||
|
subscriptionId: subscription.id,
|
||||||
|
originalPlan: subscription.plan,
|
||||||
|
newPlan: planFromStripe,
|
||||||
|
resolvedReferenceId: resolvedSubscription.referenceId,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
} catch (orgError) {
|
} catch (orgError) {
|
||||||
logger.error(
|
logger.error(
|
||||||
'[onSubscriptionUpdate] Failed to ensure organization for team subscription',
|
'[onSubscriptionUpdate] Failed to ensure organization for team subscription',
|
||||||
{
|
{
|
||||||
subscriptionId: subscription.id,
|
subscriptionId: subscription.id,
|
||||||
referenceId: subscription.referenceId,
|
referenceId: subscription.referenceId,
|
||||||
plan: subscription.plan,
|
dbPlan: subscription.plan,
|
||||||
|
planFromStripe,
|
||||||
|
isUpgradeToTeam,
|
||||||
error: orgError instanceof Error ? orgError.message : String(orgError),
|
error: orgError instanceof Error ? orgError.message : String(orgError),
|
||||||
stack: orgError instanceof Error ? orgError.stack : undefined,
|
stack: orgError instanceof Error ? orgError.stack : undefined,
|
||||||
}
|
}
|
||||||
@@ -2717,9 +2775,8 @@ export const auth = betterAuth({
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if (resolvedSubscription.plan === 'team') {
|
if (effectivePlanForTeamFeatures === 'team') {
|
||||||
try {
|
try {
|
||||||
const stripeSubscription = event.data.object as Stripe.Subscription
|
|
||||||
const quantity = stripeSubscription.items?.data?.[0]?.quantity || 1
|
const quantity = stripeSubscription.items?.data?.[0]?.quantity || 1
|
||||||
|
|
||||||
const result = await syncSeatsFromStripeQuantity(
|
const result = await syncSeatsFromStripeQuantity(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import type Stripe from 'stripe'
|
||||||
import {
|
import {
|
||||||
getFreeTierLimit,
|
getFreeTierLimit,
|
||||||
getProTierLimit,
|
getProTierLimit,
|
||||||
@@ -56,6 +57,13 @@ export function getPlanByName(planName: string): BillingPlan | undefined {
|
|||||||
return getPlans().find((plan) => plan.name === planName)
|
return getPlans().find((plan) => plan.name === planName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get a specific plan by Stripe price ID
|
||||||
|
*/
|
||||||
|
export function getPlanByPriceId(priceId: string): BillingPlan | undefined {
|
||||||
|
return getPlans().find((plan) => plan.priceId === priceId)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get plan limits for a given plan name
|
* Get plan limits for a given plan name
|
||||||
*/
|
*/
|
||||||
@@ -63,3 +71,26 @@ export function getPlanLimits(planName: string): number {
|
|||||||
const plan = getPlanByName(planName)
|
const plan = getPlanByName(planName)
|
||||||
return plan?.limits.cost ?? getFreeTierLimit()
|
return plan?.limits.cost ?? getFreeTierLimit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface StripePlanResolution {
|
||||||
|
priceId: string | undefined
|
||||||
|
planFromStripe: string | null
|
||||||
|
isTeamPlan: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Resolve plan information from a Stripe subscription object.
|
||||||
|
* Used to get the authoritative plan from Stripe rather than relying on DB state.
|
||||||
|
*/
|
||||||
|
export function resolvePlanFromStripeSubscription(
|
||||||
|
stripeSubscription: Stripe.Subscription
|
||||||
|
): StripePlanResolution {
|
||||||
|
const priceId = stripeSubscription?.items?.data?.[0]?.price?.id
|
||||||
|
const plan = priceId ? getPlanByPriceId(priceId) : undefined
|
||||||
|
|
||||||
|
return {
|
||||||
|
priceId,
|
||||||
|
planFromStripe: plan?.name ?? null,
|
||||||
|
isTeamPlan: plan?.name === 'team',
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user