Compare commits

...

83 Commits

Author SHA1 Message Date
Waleed
76fac13f3d v0.3.41: wand with azure openai, generic mysql and postgres blocks 2025-08-29 19:19:29 -07:00
Waleed
a3838302e0 feat(kb): add adjustable concurrency and batching to uploads and embeddings (#1198) 2025-08-29 18:37:23 -07:00
Waleed
4310dd6c15 improvement(pg): added wand config for writing sql queries for generic db blocks & supabase postgrest syntax (#1197)
* add parallel ai, postgres, mysql, slight modifications to dark mode styling

* bun install frozen lockfile

* new deps

* improve security, add wand to short input and update wand config
2025-08-29 18:32:07 -07:00
Waleed
813a0fb741 feat(tools): add parallel ai, postgres, mysql, slight modifications to dark mode styling (#1192)
* add parallel ai, postgres, mysql, slight modifications to dark mode styling

* bun install frozen lockfile

* new deps
2025-08-29 17:25:02 -07:00
Waleed
7e23e942d7 fix(billing-ui): open settings when enterprise sub folks press usage indicator (#1194) 2025-08-29 16:11:32 -07:00
Siddharth Ganesan
7fcbafab97 Use direct fetch (#1193) 2025-08-29 16:10:36 -07:00
Siddharth Ganesan
056dc2879c Fix/wand (#1191)
* Switch to node

* Refactor
2025-08-29 15:50:26 -07:00
Siddharth Ganesan
1aec32b7e2 Switch to node (#1190) 2025-08-29 15:18:07 -07:00
Vikhyath Mondreti
316c9704af Merge pull request #1189 from simstudioai/staging
fix(deps): revert dependencies to before pg block was added
2025-08-29 14:28:31 -07:00
Vikhyath Mondreti
4e3a3bd1b1 run bun install 2025-08-29 14:23:31 -07:00
Vikhyath Mondreti
36773e8cdb Revert "feat(integrations): added parallel AI, mySQL, and postgres block/tools (#1126)"
This reverts commit 766279bb8b.
2025-08-29 14:14:45 -07:00
Vikhyath Mondreti
7ac89e35a1 revert(dep-changes): revert drizzle-orm version and change CI yaml script 2025-08-29 13:51:36 -07:00
Vikhyath Mondreti
faa094195a change bun install to be based on frozen-lockfile flag
2025-08-29 13:42:20 -07:00
Vikhyath Mondreti
69319d21cd revert drizzle-orm version 2025-08-29 13:36:57 -07:00
Vikhyath Mondreti
8362fd7a83 remove bun lock 2025-08-29 13:34:46 -07:00
Vikhyath Mondreti
39ad793a9a revert package.json 2025-08-29 13:34:19 -07:00
Waleed
921c755711 v0.3.40: drizzle fixes, custom postgres port support 2025-08-29 10:24:40 -07:00
Waleed
41ec75fcad fix(pg): fix POSTGRES_PORT envvar to map external port to 5432 internally (#1187) 2025-08-29 10:11:37 -07:00
Waleed
f2502f5e48 fix(database): revert changes related to db URL (#1185)
* fix(database): revert changes related to db URL

* cleanup
2025-08-29 09:33:40 -07:00
Vikhyath Mondreti
f3c4f7e20a fix 2025-08-29 00:35:15 -07:00
Vikhyath Mondreti
f578f43c9a graceful exit for drizzle migration 2025-08-29 00:25:47 -07:00
Vikhyath Mondreti
5c73038023 fix(db): attempt parsing cert and db url separately (#1183) 2025-08-29 00:17:05 -07:00
Waleed
92132024ca fix(db): accept self-signed certs (#1181) 2025-08-28 23:19:43 -07:00
Waleed
ed11456de3 fix(db): accept self-signed certs (#1181) 2025-08-28 23:19:02 -07:00
Waleed
8739a3d378 fix(ssl): add envvar for optional ssl cert (#1179) 2025-08-28 23:11:21 -07:00
Waleed
ca015deea9 fix(ssl): add envvar for optional ssl cert (#1179) 2025-08-28 23:00:43 -07:00
Waleed
fd6d927228 v0.3.40: copilot improvements, knowledgebase improvements, security improvements, billing fixes 2025-08-28 22:00:58 -07:00
Adam Gough
6ac59a3264 Revert "fix(cursor-and-input): fixes cursor and input canvas error (#1168)" (#1178)
This reverts commit aa84c75360.
2025-08-28 21:06:30 -07:00
Adam Gough
aa84c75360 fix(cursor-and-input): fixes cursor and input canvas error (#1168)
* fixed long input

* lint

* fix gray canvas

* fixed auto-pan

* remove duplicate useEffect

* fix auto-pan for wide mode

* removed any

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-28 20:17:10 -07:00
Vikhyath Mondreti
ebb8cf8bf9 fix(slack): set depends on for slack channel subblock (#1177)
* fix(slack): set depends on for slack channel

* use foreign credential check

* fix

* fix clearing of block
2025-08-28 20:11:30 -07:00
Siddharth Ganesan
cadfcdbfbd Fix (#1176) 2025-08-28 19:21:29 -07:00
Vikhyath Mondreti
7d62c200fa feat(openrouter): add open router to model block (#1172)
* feat(openrouter): add open router to model block

* improvement(openrouter): streaming fix, temperature fix

* pr comments

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-08-28 18:47:36 -07:00
Siddharth Ganesan
df646256b3 Revert "feat(debug): create debugger (#1174)" (#1175)
This reverts commit 7c73f5ffe0.
2025-08-28 18:46:40 -07:00
Siddharth Ganesan
7c73f5ffe0 feat(debug): create debugger (#1174)
* Updates

* Updates

* Updates

* Checkpoint

* Checkpoint

* Checkpoint

* Var improvements

* Fixes

* Execution status

* UI improvements

* Ui updates

* Fix

* Fix scoping

* Fix workflow vars

* Fix env vars

* Remove number styling

* Variable highlighting

* Updates

* Update

* Fix resume

* Stuff

* Breakpoint ui

* Ui

* Ui updates

* Loops and parallels

* Hide env vars

* Checkpoint

* Stuff

* Panel toggle

* Lint
2025-08-28 18:19:20 -07:00
Waleed
bb5f40a027 feat(pg): added ability to customize postgres port when running containerized app (#1173) 2025-08-28 17:16:24 -07:00
Waleed
5ae5429296 chore(deps): upgrade trigger.dev in gh action (#1171) 2025-08-28 17:08:59 -07:00
Waleed
fcf128f6db improvement(knowledge): remove innerJoin and add id identifiers to results, updated docs (#1170)
* improvement(knowledge): remove innerJoin and add id identifiers to results, updated docs

* cleanup

* add documentName to upload chunk op as well
2025-08-28 17:04:31 -07:00
Vikhyath Mondreti
56543dafb4 fix(billing): usage tracking cleanup, shared pool of limits for team/enterprise (#1131)
* fix(billing): team usage tracking cleanup, shared pool of limits for team

* address greptile comments

* fix lint

* remove usage of deprecated cols

* update periodStart and periodEnd correctly

* fix lint

* fix type issue

* fix(billing): cleaned up billing, still more work to do on UI and population of data and consolidation

* fix upgrade

* cleanup

* progress

* works

* Remove 78th migration to prepare for merge with staging

* fix migration conflict

* remove useless test file

* fix

* Fix undefined seat pricing display and handle cancelled subscription seat updates

* cleanup code

* cleanup to use helpers for pulling pricing limits

* cleanup more things

* cleanup

* restore environment ts file

* remove unused files

* fix(team-management): fix team management UI, consolidate components

* use session data instead of subscription data in settings navigation

* remove unused code

* fix UI for enterprise plans

* added enterprise plan support

* progress

* billing state machine

* split overage and base into separate invoices

* fix badge logic

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-08-28 17:00:48 -07:00
Emir Karabeg
7cc4574913 improvement(knowledge): search returns document name (#1167) 2025-08-28 16:07:22 -07:00
Waleed
3f900947ce improvement(kb): use trigger.dev for kb tasks (#1166) 2025-08-28 12:14:31 -07:00
Waleed
bda8ee772a fix(security): strengthen email invite validation logic, fix invite page UI (#1162)
* fix(security): strengthen email invite validation logic, fix invite page UI

* ui
2025-08-28 00:03:03 -07:00
Siddharth Ganesan
104d34cc9e fix(copilot): context filtering (#1160)
* Add filter

* Scope kb and chats

* Lint

* Remove comments

* Lint
2025-08-27 22:57:28 -07:00
Siddharth Ganesan
06e9a6b302 feat(copilot): context (#1157)
* Copilot updates

* Set/get vars

* Credentials opener v1

* Progress

* Checkpoint?

* Context v1

* Workflow references

* Add knowledge base context

* Blocks

* Templates

* Much better pills

* workflow updates

* Major ui

* Workflow box colors

* Much improved ui

* Improvements

* Much better

* Add @ icon

* Welcome page

* Update tool names

* Matches

* Update ordering

* Good sort

* Good @ handling

* Update placeholder

* Updates

* Lint

* Almost there

* Wrapped up?

* Lint

* Build error fix

* Build fix?

* Lint

* Fix load vars
2025-08-27 21:07:51 -07:00
Waleed
fed4e507cc fix(signup): refetch session data on signup (#1155) 2025-08-27 20:01:04 -07:00
Waleed
389456e0f3 fix(envvars): fix split for pasting envvars with query params (#1156) 2025-08-27 19:55:54 -07:00
Vikhyath Mondreti
c720f23d9b fix(sockets): useCollabWorkflow cleanup, variables store logic simplification (#1154)
* fix(sockets): useCollabWorkflow cleanup, variables store logic simplification

* remove unnecessary check
2025-08-27 17:11:39 -07:00
Vikhyath Mondreti
89f7d2b943 improvement(sockets): cleanup debounce logic + add flush mechanism to… (#1152)
* improvement(sockets): cleanup debounce logic + add flush mechanism to not lose ops

* fix optimistic update overwritten race condition

* fix

* fix forever stuck in processing
2025-08-27 11:35:20 -07:00
Emir Karabeg
923c05239c fix(auto-layout): revert (#1148) 2025-08-26 23:24:09 -07:00
Waleed
3424a338b7 fix(security): fixed SSRF vulnerability (#1149) 2025-08-26 23:11:08 -07:00
Waleed
51b1e97fa2 fix(kb-uploads): created knowledge, chunks, tags services and use redis for queueing docs in kb (#1143)
* improvement(kb): created knowledge, chunks, tags services and use redis for queueing docs in kb

* moved directories around

* cleanup

* bulk create document records after upload is completed

* fix(copilot): send api key to sim agent (#1142)

* Fix api key auth

* Lint

* ack PR comments

* added sort by functionality for headers in kb table

* updated

* test fallback from redis, fix styling

* cleanup copilot, fixed tooltips

* feat: local auto layout (#1144)

* feat: added llms.txt and robots.txt (#1145)

* fix(condition-block): edges not following blocks, duplicate issues (#1146)

* fix(condition-block): edges not following blocks, duplicate issues

* add subblock update to setActiveWorkflow

* Update apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/components/condition-input.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* fix dependency array

* fix(copilot-cleanup): support azure blob upload in copilot, remove dead code & consolidate other copilot files (#1147)

* cleanup

* support azure blob image upload

* imports cleanup

* PR comments

* ack PR comments

* fix key validation

* improvement(forwarding+excel): added forwarding and improve excel read (#1136)

* added forwarding for outlook

* lint

* improved excel sheet read

* addressed greptile

* fixed bodytext getting truncated

* fixed any type

* added html func

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>

* revert agent const

* update docs

---------

Co-authored-by: Siddharth Ganesan <33737564+Sg312@users.noreply.github.com>
Co-authored-by: Emir Karabeg <78010029+emir-karabeg@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyathvikku@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
Co-authored-by: Adam Gough <77861281+aadamgough@users.noreply.github.com>
Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-26 22:55:18 -07:00
Adam Gough
ab74b13802 improvement(forwarding+excel): added forwarding and improve excel read (#1136)
* added forwarding for outlook

* lint

* improved excel sheet read

* addressed greptile

* fixed bodytext getting truncated

* fixed any type

* added html func

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-26 21:18:09 -07:00
Vikhyath Mondreti
861ab1446a Merge branch 'staging' of github.com:simstudioai/sim into staging 2025-08-26 20:09:13 -07:00
Vikhyath Mondreti
e6f519a5a6 fix dependency array 2025-08-26 20:08:37 -07:00
Waleed
8226e7b40a fix(copilot-cleanup): support azure blob upload in copilot, remove dead code & consolidate other copilot files (#1147)
* cleanup

* support azure blob image upload

* imports cleanup

* PR comments

* ack PR comments

* fix key validation
2025-08-26 20:06:43 -07:00
Vikhyath Mondreti
b177b291cf fix(condition-block): edges not following blocks, duplicate issues (#1146)
* fix(condition-block): edges not following blocks, duplicate issues

* add subblock update to setActiveWorkflow

* Update apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/components/condition-input.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-26 19:51:55 -07:00
Emir Karabeg
9c3b43325b feat: added llms.txt and robots.txt (#1145) 2025-08-26 19:04:27 -07:00
Emir Karabeg
973a5c6497 feat: local auto layout (#1144) 2025-08-26 19:03:09 -07:00
Siddharth Ganesan
78437c688e fix(copilot): send api key to sim agent (#1142)
* Fix api key auth

* Lint
2025-08-26 16:01:42 -07:00
Vikhyath Mondreti
3b74250335 fix(subblock-race-condition): check loading state correctly (#1141)
* fix(subblock-race-condition): check loading state correctly

* clean up

* remove useless comments

* fix date fallback
2025-08-26 12:14:58 -07:00
Waleed
c68800c772 feat(login): add terms and privacy to signup and login pages (#1139) 2025-08-26 11:19:17 -07:00
Siddharth Ganesan
5403665fa9 Docs update (#1140) 2025-08-26 11:16:07 -07:00
Siddharth Ganesan
3d3443f68e fix(copilot): enterprise api keys (#1138)
* Copilot enterprise

* Fix validation and enterprise azure keys

* Lint

* update tests

* Update

* Lint

* Remove hardcoded ishosted

* Lint

* Updates

* Add tests
2025-08-26 10:55:08 -07:00
Emir Karabeg
e5c0b14367 improvement(help-modal): ui/ux (#1135) 2025-08-25 19:36:38 -07:00
Siddharth Ganesan
a495516901 feat(copilot): enable azure openai and move key validation (#1134)
* Copilot enterprise

* Fix validation and enterprise azure keys

* Lint

* update tests

* Update

* Lint

* Remove hardcoded ishosted

* Lint
2025-08-25 18:03:08 -07:00
Waleed
1f9b4a8ef0 fix(wand): remove unstable__noStore, add additional logs for wand generation (#1133)
* feat(wand): added additional logs for wand generation

* remove unstable__noStore
2025-08-25 16:20:41 -07:00
Waleed
3372829c30 fix(wand): remove edge runtime for wand (#1132) 2025-08-25 14:21:27 -07:00
Waleed
45372aece5 fix(files): fix vulnerabilities in file uploads/deletes (#1130)
* fix(vulnerability): fix arbitrary file deletion vuln

* fix(uploads): fix vuln during upload

* cleanup
2025-08-25 11:26:42 -07:00
Waleed Latif
ed9b9ad83f v0.3.39: billing fixes, custom tools fixes, copilot client-side migration, new tools 2025-08-24 00:18:25 -07:00
Waleed Latif
766279bb8b feat(integrations): added parallel AI, mySQL, and postgres block/tools (#1126)
* feat(integrations): added parallel ai block/tool and corresponding docs

* add postgres block

* added mysql block

* enrich docs for Postgres and MySQL

* make password fields user only for mysql and postgres

* fixed build

* ack greptile comments

* fix PR comments

* remove search_id from parallel ai

* fix parallel ai params
2025-08-23 21:43:55 -07:00
Adam Gough
1038e148c3 fix autoconnect (#1127) 2025-08-23 20:46:03 -07:00
Adam Gough
8b78200991 fix(onedrive): fixed advanced mode (#1122)
* fixed onedrive advanced mode

* removed logger

* removed logger

* added a slack instruction

* remove folderId

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-23 19:49:13 -07:00
Siddharth Ganesan
c8f4791582 Feat/copilot client clean (#1118)
* SSE tool call v1 - not tested yet

* Handle tool call generation sse

* Add mark complete api

* copilot new progress

* Migrate get user workflow

* Run workflow migrated

* Migrate run workflow and remove some dead code

* Migrate gdrive request access

* Add server side execution logic

* Get block metadata migrated

* Build workflow progress

* Somewhat working condition, build still broken

* Stuff

* Get workflow console

* search online tool

* Set/get env vars

* oauth, gdrive list, gdrive read

* Search docs

* Build workflow update

* Edit workflow

* Migrate plan tool

* Checkoff

* Refactor

* Improvement

* checkpoint

* New store basics

* Generating adds to map

* Update

* Display v1

* Update

* Stuff

* Stuff

* Stuff

* Edit works

* Interrupt tool fixes

* Interrupt tool fixes

* Good progress

* new copilot to copilot

* Fix chat loading

* Skip rendering of non registered tools

* Small fix

* Updates

* Updates

* Updates

* Update

* Some fixes

* Revert fixes

* run workflow

* Move to background button shows up

* User input scroll bar

* Lint

* Build errors

* Diff controls

* Restore ui

* Ui fixes

* Max mode ui

* Thinking text collapse

* Tool ui updates

* Mode selector UI

* Lint

* Ui

* Update icon

* Dummy test

* Lint
2025-08-23 18:11:10 -07:00
Vikhyath Mondreti
6c9e0ec88b improvement(logging): capture pre-execution validation errors in logging session (#1124)
* improvement(pre-exec-errors): capture pre-execution validation errors in logging session

* fix param shape for schedules

* fix naming
2025-08-23 18:08:57 -07:00
Adam Gough
bbbf1c2941 fix(teams-wh): fixed teams wh payload (#1119)
* first push

* fixed variable res

* lint

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-23 17:59:00 -07:00
Vikhyath Mondreti
efc487a845 improvement(chat-file-upload): add visual indication of file upload exceeding limit (#1123)
* improvement(chat-file-upload): add visual indication of file upload exceeding limit

* fix duplicate error + lint

* fix lint

* fix lint
2025-08-23 17:08:41 -07:00
Vikhyath Mondreti
5786909c5e fix(tag-dropdown): arrow navigation for submenu affecting text input cursor (#1121) 2025-08-23 16:19:45 -07:00
Vikhyath Mondreti
833c5fefd5 fix(logs): fix to remove retrieval of execution of data for basic version of call (#1120) 2025-08-23 15:51:08 -07:00
Adam Gough
79dd1ccb9f fix(ux): minor ux changes (#1109)
* minor UX fixes

* changed variable collapse

* lint

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-23 15:50:40 -07:00
Waleed Latif
730164abee fix(custom-tool): fix textarea, param dropdown for available params, validation for invalid schemas, variable resolution in custom tools and subflow tags (#1117)
* fix(custom-tools): fix text area for custom tools

* added param dropdown in agent custom tool

* add syntax highlighting for params, fix dropdown styling

* ux

* add tooltip to indicate invalid JSON schema on schema and code tabs

* feat(custom-tool): added stricter JSON schema validation and error when saving json schema for custom tools

* fix(custom-tool): allow variable resolution in custom tools

* fix variable resolution in subflow tags

* refactored function execution to use helpers

* cleanup

* fix block variable resolution to inject at runtime

* fix highlighting code

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
2025-08-23 13:15:12 -07:00
Vikhyath Mondreti
25b2c45ec0 fix(billing): change reset user stats func to invoice payment succeeded (#1116)
* fix(billing): change reset user stats func to invoice payment succeeded

* remove nonexistent billing reason
2025-08-23 10:50:23 -07:00
Vikhyath Mondreti
780870c48e fix(billing): make subscription table source of truth for period start and period end (#1114)
* fix(billing): vercel cron not processing billing periods

* fix(billing): cleanup unused POST and fix bug with billing timing check

* make subscriptions table source of truth for dates

* update org routes

* make everything dependent on stripe webhook

---------

Co-authored-by: Waleed Latif <walif6@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Adam Gough <77861281+aadamgough@users.noreply.github.com>
Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-23 10:25:41 -07:00
Vikhyath Mondreti
fdfa935a09 v0.3.38: billing cron job fix 2025-08-22 17:03:36 -07:00
Vikhyath Mondreti
917552f041 fix(billing): vercel cron not processing billing periods (#1112) 2025-08-22 16:52:31 -07:00
445 changed files with 41069 additions and 21626 deletions

View File

@@ -77,7 +77,7 @@ services:
- POSTGRES_PASSWORD=postgres
- POSTGRES_DB=simstudio
ports:
- "5432:5432"
- "${POSTGRES_PORT:-5432}:5432"
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
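The remapping above keeps the container listening on 5432 internally while the host-side port becomes configurable. A minimal usage sketch, assuming this compose file and the stock psql client; the port 5433 is only an example, and the postgres/postgres credentials and simstudio database come from the environment block shown in this hunk:

# Expose the bundled Postgres on host port 5433; in-container consumers still use 5432
POSTGRES_PORT=5433 docker compose up -d

# Verify from the host through the remapped port
psql "postgresql://postgres:postgres@localhost:5433/simstudio" -c "SELECT 1;"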

View File

@@ -26,7 +26,7 @@ jobs:
node-version: latest
- name: Install dependencies
- run: bun install
+ run: bun install --frozen-lockfile
- name: Run tests with coverage
env:

View File

@@ -35,10 +35,10 @@ jobs:
- name: Deploy to Staging
if: github.ref == 'refs/heads/staging'
working-directory: ./apps/sim
- run: npx --yes trigger.dev@4.0.0 deploy -e staging
+ run: npx --yes trigger.dev@4.0.1 deploy -e staging
- name: Deploy to Production
if: github.ref == 'refs/heads/main'
working-directory: ./apps/sim
- run: npx --yes trigger.dev@4.0.0 deploy
+ run: npx --yes trigger.dev@4.0.1 deploy

View File

@@ -160,7 +160,6 @@ Copilot is a Sim-managed service. To use Copilot on a self-hosted instance:
- Go to https://sim.ai → Settings → Copilot and generate a Copilot API key
- Set `COPILOT_API_KEY` in your self-hosted environment to that value
- Host Sim on a publicly available DNS and set NEXT_PUBLIC_APP_URL and BETTER_AUTH_URL to that value ([ngrok](https://ngrok.com/))
## Tech Stack
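The Copilot setup described in this README hunk reduces to a single environment variable on the self-hosted side. A minimal sketch of the resulting configuration, assuming a dotenv-style file; the key value is a placeholder:

# .env on the self-hosted instance; the key is generated at sim.ai → Settings → Copilot
COPILOT_API_KEY=your-generated-copilot-key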

View File

@@ -7,8 +7,6 @@ import { Callout } from 'fumadocs-ui/components/callout'
import { Card, Cards } from 'fumadocs-ui/components/card'
import { MessageCircle, Package, Zap, Infinity as InfinityIcon, Brain, BrainCircuit } from 'lucide-react'
## What is Copilot
Copilot is your in-editor assistant that helps you build, understand, and improve workflows. It can:
- **Explain**: Answer questions about Sim and your current workflow
@@ -18,35 +16,34 @@ Copilot is your in-editor assistant that helps you build, understand, and improv
<Callout type="info">
Copilot is a Sim-managed service. For self-hosted deployments, generate a Copilot API key in the hosted app (sim.ai → Settings → Copilot)
1. Go to [sim.ai](https://sim.ai) → Settings → Copilot and generate a Copilot API key
2. Set `COPILOT_API_KEY` in your self-hosted environment to that value
3. Host Sim on a publicly available DNS and set `NEXT_PUBLIC_APP_URL` and `BETTER_AUTH_URL` to that value (e.g., using ngrok)
</Callout>
## Modes
<Cards>
<Card title="Ask">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<MessageCircle className="h-4 w-4 text-muted-foreground" />
Ask
</span>
<div>
<p className="m-0 text-sm">
Q&A mode for explanations, guidance, and suggestions without making changes to your workflow.
</p>
</div>
}
>
<div className="m-0 text-sm">
Q&A mode for explanations, guidance, and suggestions without making changes to your workflow.
</div>
</Card>
<Card title="Agent">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Package className="h-4 w-4 text-muted-foreground" />
Agent
</span>
<div>
<p className="m-0 text-sm">
Build-and-edit mode. Copilot proposes specific edits (add blocks, wire variables, tweak settings) and applies them when you approve.
</p>
</div>
}
>
<div className="m-0 text-sm">
Build-and-edit mode. Copilot proposes specific edits (add blocks, wire variables, tweak settings) and applies them when you approve.
</div>
</Card>
</Cards>
@@ -54,44 +51,44 @@ Copilot is your in-editor assistant that helps you build, understand, and improv
## Depth Levels
<Cards>
<Card title="Fast">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Zap className="h-4 w-4 text-muted-foreground" />
Fast
</span>
<div>
<p className="m-0 text-sm">Quickest and cheapest. Best for small edits, simple workflows, and minor tweaks.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Quickest and cheapest. Best for small edits, simple workflows, and minor tweaks.</div>
</Card>
<Card title="Auto">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<InfinityIcon className="h-4 w-4 text-muted-foreground" />
Auto
</span>
<div>
<p className="m-0 text-sm">Balanced speed and reasoning. Recommended default for most tasks.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Balanced speed and reasoning. Recommended default for most tasks.</div>
</Card>
<Card title="Pro">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Brain className="h-4 w-4 text-muted-foreground" />
Advanced
</span>
<div>
<p className="m-0 text-sm">More reasoning for larger workflows and complex edits while staying performant.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">More reasoning for larger workflows and complex edits while staying performant.</div>
</Card>
<Card title="Max">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<BrainCircuit className="h-4 w-4 text-muted-foreground" />
Behemoth
</span>
<div>
<p className="m-0 text-sm">Maximum reasoning for deep planning, debugging, and complex architectural changes.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Maximum reasoning for deep planning, debugging, and complex architectural changes.</div>
</Card>
</Cards>

View File

@@ -33,12 +33,15 @@
"microsoft_planner",
"microsoft_teams",
"mistral_parse",
"mysql",
"notion",
"onedrive",
"openai",
"outlook",
"parallel_ai",
"perplexity",
"pinecone",
"postgresql",
"qdrant",
"reddit",
"s3",

View File

@@ -109,7 +109,7 @@ Read data from a Microsoft Excel spreadsheet
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `spreadsheetId` | string | Yes | The ID of the spreadsheet to read from |
- | `range` | string | No | The range of cells to read from |
+ | `range` | string | No | The range of cells to read from. Accepts "SheetName!A1:B2" for explicit ranges or just "SheetName" to read the used range of that sheet. If omitted, reads the used range of the first sheet. |
#### Output

View File

@@ -0,0 +1,180 @@
---
title: MySQL
description: Connect to MySQL database
---
import { BlockInfoCard } from "@/components/ui/block-info-card"
<BlockInfoCard
type="mysql"
color="#E0E0E0"
icon={true}
iconSvg={`<svg className="block-icon"
xmlns='http://www.w3.org/2000/svg'
viewBox='0 0 25.6 25.6'
>
<path
d='M179.076 94.886c-3.568-.1-6.336.268-8.656 1.25-.668.27-1.74.27-1.828 1.116.357.355.4.936.713 1.428.535.893 1.473 2.096 2.32 2.72l2.855 2.053c1.74 1.07 3.703 1.695 5.398 2.766.982.625 1.963 1.428 2.945 2.098.5.357.803.938 1.428 1.16v-.135c-.312-.4-.402-.98-.713-1.428l-1.34-1.293c-1.293-1.74-2.9-3.258-4.64-4.506-1.428-.982-4.55-2.32-5.13-3.97l-.088-.1c.98-.1 2.14-.447 3.078-.715 1.518-.4 2.9-.312 4.46-.713l2.143-.625v-.4c-.803-.803-1.383-1.874-2.23-2.632-2.275-1.963-4.775-3.882-7.363-5.488-1.383-.892-3.168-1.473-4.64-2.23-.537-.268-1.428-.402-1.74-.848-.805-.98-1.25-2.275-1.83-3.436l-3.658-7.763c-.803-1.74-1.295-3.48-2.275-5.086-4.596-7.585-9.594-12.18-17.268-16.687-1.65-.937-3.613-1.34-5.7-1.83l-3.346-.18c-.715-.312-1.428-1.16-2.053-1.562-2.543-1.606-9.102-5.086-10.977-.5-1.205 2.9 1.785 5.755 2.8 7.228.76 1.026 1.74 2.186 2.277 3.346.3.758.4 1.562.713 2.365.713 1.963 1.383 4.15 2.32 5.98.5.937 1.025 1.92 1.65 2.767.357.5.982.714 1.115 1.517-.625.893-.668 2.23-1.025 3.347-1.607 5.042-.982 11.288 1.293 15 .715 1.115 2.4 3.57 4.686 2.632 2.008-.803 1.56-3.346 2.14-5.577.135-.535.045-.892.312-1.25v.1l1.83 3.703c1.383 2.186 3.793 4.462 5.8 5.98 1.07.803 1.918 2.187 3.256 2.677v-.135h-.088c-.268-.4-.67-.58-1.027-.892-.803-.803-1.695-1.785-2.32-2.677-1.873-2.498-3.523-5.265-4.996-8.12-.715-1.383-1.34-2.9-1.918-4.283-.27-.536-.27-1.34-.715-1.606-.67.98-1.65 1.83-2.143 3.034-.848 1.918-.936 4.283-1.248 6.737-.18.045-.1 0-.18.1-1.426-.356-1.918-1.83-2.453-3.078-1.338-3.168-1.562-8.254-.402-11.913.312-.937 1.652-3.882 1.117-4.774-.27-.848-1.16-1.338-1.652-2.008-.58-.848-1.203-1.918-1.605-2.855-1.07-2.5-1.605-5.265-2.766-7.764-.537-1.16-1.473-2.365-2.232-3.435-.848-1.205-1.783-2.053-2.453-3.48-.223-.5-.535-1.294-.178-1.83.088-.357.268-.5.623-.58.58-.5 2.232.134 2.812.4 1.65.67 3.033 1.294 4.416 2.23.625.446 1.295 1.294 2.098 1.518h.938c1.428.312 3.033.1 4.37.5 2.365.76 4.506 1.874 6.426 3.08 5.844 3.703 10.664 8.968 13.92 15.26.535 1.026.758 1.963 1.25 3.034.938 2.187 2.098 4.417 3.033 6.56.938 2.097 1.83 4.24 3.168 5.98.67.937 3.346 1.427 4.55 1.918.893.4 2.275.76 3.08 1.25 1.516.937 3.033 2.008 4.46 3.034.713.534 2.945 1.65 3.078 2.54zm-45.5-38.772a7.09 7.09 0 0 0-1.828.223v.1h.088c.357.714.982 1.205 1.428 1.83l1.027 2.142.088-.1c.625-.446.938-1.16.938-2.23-.268-.312-.312-.625-.535-.937-.268-.446-.848-.67-1.206-1.026z'
transform='matrix(.390229 0 0 .38781 -46.300037 -16.856717)'
fillRule='evenodd'
fill='#00678c'
/>
</svg>`}
/>
{/* MANUAL-CONTENT-START:intro */}
The [MySQL](https://www.mysql.com/) tool enables you to connect to any MySQL database and perform a wide range of database operations directly within your agentic workflows. With secure connection handling and flexible configuration, you can easily manage and interact with your data.
With the MySQL tool, you can:
- **Query data**: Execute SELECT queries to retrieve data from your MySQL tables using the `mysql_query` operation.
- **Insert records**: Add new rows to your tables with the `mysql_insert` operation by specifying the table and data to insert.
- **Update records**: Modify existing data in your tables using the `mysql_update` operation, providing the table, new data, and WHERE conditions.
- **Delete records**: Remove rows from your tables with the `mysql_delete` operation, specifying the table and WHERE conditions.
- **Execute raw SQL**: Run any custom SQL command using the `mysql_execute` operation for advanced use cases.
The MySQL tool is ideal for scenarios where your agents need to interact with structured data—such as automating reporting, syncing data between systems, or powering data-driven workflows. It streamlines database access, making it easy to read, write, and manage your MySQL data programmatically.
{/* MANUAL-CONTENT-END */}
## Usage Instructions
Connect to any MySQL database to execute queries, manage data, and perform database operations. Supports SELECT, INSERT, UPDATE, DELETE operations with secure connection handling.
## Tools
### `mysql_query`
Execute SELECT query on MySQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MySQL server hostname or IP address |
| `port` | number | Yes | MySQL server port \(default: 3306\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `query` | string | Yes | SQL SELECT query to execute |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of rows returned from the query |
| `rowCount` | number | Number of rows returned |
### `mysql_insert`
Insert new record into MySQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MySQL server hostname or IP address |
| `port` | number | Yes | MySQL server port \(default: 3306\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to insert into |
| `data` | object | Yes | Data to insert as key-value pairs |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of inserted rows |
| `rowCount` | number | Number of rows inserted |
### `mysql_update`
Update existing records in MySQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MySQL server hostname or IP address |
| `port` | number | Yes | MySQL server port \(default: 3306\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to update |
| `data` | object | Yes | Data to update as key-value pairs |
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of updated rows |
| `rowCount` | number | Number of rows updated |
### `mysql_delete`
Delete records from MySQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MySQL server hostname or IP address |
| `port` | number | Yes | MySQL server port \(default: 3306\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to delete from |
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of deleted rows |
| `rowCount` | number | Number of rows deleted |
### `mysql_execute`
Execute raw SQL query on MySQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MySQL server hostname or IP address |
| `port` | number | Yes | MySQL server port \(default: 3306\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `query` | string | Yes | Raw SQL query to execute |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of rows returned from the query |
| `rowCount` | number | Number of rows affected |
## Notes
- Category: `tools`
- Type: `mysql`
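For orientation, the statements these operations issue can be reproduced with the stock mysql client. A sketch with placeholder connection values matching the host/port/username/password/database parameters documented above; table and column names are illustrative:

# Equivalent of mysql_query, using the same connection details the block's fields describe
mysql -h db.example.com -P 3306 -u app_user -p'app_password' appdb \
  -e "SELECT id, name FROM customers WHERE created_at >= '2025-08-01';"

# mysql_update and mysql_delete take `where` without the WHERE keyword;
# a `where` value of "id = 42" corresponds to this statement:
mysql -h db.example.com -P 3306 -u app_user -p'app_password' appdb \
  -e "DELETE FROM customers WHERE id = 42;"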

View File

@@ -68,7 +68,7 @@ Upload a file to OneDrive
| `fileName` | string | Yes | The name of the file to upload |
| `content` | string | Yes | The content of the file to upload |
| `folderSelector` | string | No | Select the folder to upload the file to |
- | `folderId` | string | No | The ID of the folder to upload the file to \(internal use\) |
+ | `manualFolderId` | string | No | Manually entered folder ID \(advanced mode\) |
#### Output
@@ -87,7 +87,7 @@ Create a new folder in OneDrive
| --------- | ---- | -------- | ----------- |
| `folderName` | string | Yes | Name of the folder to create |
| `folderSelector` | string | No | Select the parent folder to create the folder in |
- | `folderId` | string | No | ID of the parent folder \(internal use\) |
+ | `manualFolderId` | string | No | Manually entered parent folder ID \(advanced mode\) |
#### Output
@@ -105,7 +105,7 @@ List files and folders in OneDrive
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `folderSelector` | string | No | Select the folder to list files from |
- | `folderId` | string | No | The ID of the folder to list files from \(internal use\) |
+ | `manualFolderId` | string | No | The manually entered folder ID \(advanced mode\) |
| `query` | string | No | A query to filter the files |
| `pageSize` | number | No | The number of files to return |

View File

@@ -211,10 +211,27 @@ Read emails from Outlook
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `success` | boolean | Email read operation success status |
- | `messageCount` | number | Number of emails retrieved |
- | `messages` | array | Array of email message objects |
+ | `message` | string | Success or status message |
+ | `results` | array | Array of email message objects |
### `outlook_forward`
Forward an existing Outlook message to specified recipients
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `messageId` | string | Yes | The ID of the message to forward |
| `to` | string | Yes | Recipient email address\(es\), comma-separated |
| `comment` | string | No | Optional comment to include with the forwarded message |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Success or error message |
| `results` | object | Delivery result details |

View File

@@ -0,0 +1,106 @@
---
title: Parallel AI
description: Search with Parallel AI
---
import { BlockInfoCard } from "@/components/ui/block-info-card"
<BlockInfoCard
type="parallel_ai"
color="#E0E0E0"
icon={true}
iconSvg={`<svg className="block-icon"
fill='currentColor'
viewBox='0 0 271 270'
xmlns='http://www.w3.org/2000/svg'
>
<path
d='M267.804 105.65H193.828C194.026 106.814 194.187 107.996 194.349 109.178H76.6703C76.4546 110.736 76.2388 112.312 76.0591 113.87H1.63342C1.27387 116.198 0.950289 118.543 0.698608 120.925H75.3759C75.2501 122.483 75.1602 124.059 75.0703 125.617H195.949C196.003 126.781 196.057 127.962 196.093 129.144H270.68V125.384C270.195 118.651 269.242 112.061 267.804 105.65Z'
fill='#1D1C1A'
/>
<path
d='M195.949 144.401H75.0703C75.1422 145.977 75.2501 147.535 75.3759 149.093H0.698608C0.950289 151.457 1.2559 153.802 1.63342 156.148H76.0591C76.2388 157.724 76.4366 159.282 76.6703 160.84H194.349C194.187 162.022 194.008 163.186 193.828 164.367H267.804C269.242 157.957 270.195 151.367 270.68 144.634V140.874H196.093C196.057 142.055 196.003 143.219 195.949 144.401Z'
fill='#1D1C1A'
/>
<path
d='M190.628 179.642H80.3559C80.7514 181.218 81.1828 182.776 81.6143 184.334H9.30994C10.2448 186.715 11.2515 189.061 12.3121 191.389H83.7536C84.2749 192.965 84.7962 194.523 85.3535 196.08H185.594C185.163 197.262 184.732 198.426 184.282 199.608H254.519C258.6 192.177 261.98 184.316 264.604 176.114H191.455C191.185 177.296 190.898 178.46 190.61 179.642H190.628Z'
fill='#1D1C1A'
/>
<path
d='M177.666 214.883H93.3352C94.1082 216.458 94.9172 218.034 95.7441 219.574H29.8756C31.8351 221.992 33.8666 224.337 35.9699 226.63H99.6632C100.598 228.205 101.551 229.781 102.522 231.321H168.498C167.761 232.503 167.006 233.685 166.233 234.849H226.762C234.474 227.847 241.36 219.95 247.292 211.355H179.356C178.799 212.537 178.26 213.719 177.684 214.883H177.666Z'
fill='#1D1C1A'
/>
<path
d='M154.943 250.106H116.058C117.371 251.699 118.701 253.257 120.067 254.797H73.021C91.6094 264.431 112.715 269.946 135.096 270C135.24 270 135.366 270 135.492 270C135.618 270 135.761 270 135.887 270C164.04 269.911 190.178 261.28 211.805 246.56H157.748C156.813 247.742 155.878 248.924 154.925 250.088L154.943 250.106Z'
fill='#1D1C1A'
/>
<path
d='M116.059 19.9124H154.943C155.896 21.0764 156.831 22.2582 157.766 23.4401H211.823C190.179 8.72065 164.058 0.0895344 135.906 0C135.762 0 135.636 0 135.51 0C135.384 0 135.24 0 135.115 0C112.715 0.0716275 91.6277 5.56904 73.0393 15.2029H120.086C118.719 16.7429 117.389 18.3187 116.077 19.8945L116.059 19.9124Z'
fill='#1D1C1A'
/>
<path
d='M93.3356 55.1532H177.667C178.242 56.3171 178.799 57.499 179.339 58.6808H247.274C241.342 50.0855 234.457 42.1886 226.744 35.187H166.215C166.988 36.351 167.743 37.5328 168.48 38.7147H102.504C101.533 40.2726 100.58 41.8305 99.6456 43.4063H35.9523C33.831 45.6804 31.7996 48.0262 29.858 50.4616H95.7265C94.8996 52.0195 94.1086 53.5774 93.3176 55.1532H93.3356Z'
fill='#1D1C1A'
/>
<path
d='M80.3736 90.3758H190.646C190.933 91.5398 191.221 92.7216 191.491 93.9035H264.64C262.015 85.7021 258.636 77.841 254.555 70.4097H184.318C184.767 71.5736 185.199 72.7555 185.63 73.9373H85.3893C84.832 75.4952 84.2927 77.0531 83.7893 78.6289H12.3479C11.2872 80.9389 10.2805 83.2847 9.3457 85.6842H81.65C81.2186 87.2421 80.7871 88.8 80.3916 90.3758H80.3736Z'
fill='#1D1C1A'
/>
</svg>`}
/>
{/* MANUAL-CONTENT-START:intro */}
[Parallel AI](https://parallel.ai/) is an advanced web search and content extraction platform designed to deliver comprehensive, high-quality results for any query. By leveraging intelligent processing and large-scale data extraction, Parallel AI enables users and agents to access, analyze, and synthesize information from across the web with speed and accuracy.
With Parallel AI, you can:
- **Search the web intelligently**: Retrieve relevant, up-to-date information from a wide range of sources
- **Extract and summarize content**: Get concise, meaningful excerpts from web pages and documents
- **Customize search objectives**: Tailor queries to specific needs or questions for targeted results
- **Process results at scale**: Handle large volumes of search results with advanced processing options
- **Integrate with workflows**: Use Parallel AI within Sim to automate research, content gathering, and knowledge extraction
- **Control output granularity**: Specify the number of results and the amount of content per result
- **Secure API access**: Protect your searches and data with API key authentication
In Sim, the Parallel AI integration empowers your agents to perform web searches and extract content programmatically. This enables powerful automation scenarios such as real-time research, competitive analysis, content monitoring, and knowledge base creation. By connecting Sim with Parallel AI, you unlock the ability for agents to gather, process, and utilize web data as part of your automated workflows.
{/* MANUAL-CONTENT-END */}
## Usage Instructions
Search the web using Parallel AI's advanced search capabilities. Get comprehensive results with intelligent processing and content extraction.
## Tools
### `parallel_search`
Search the web using Parallel AI. Provides comprehensive search results with intelligent processing and content extraction.
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `objective` | string | Yes | The search objective or question to answer |
| `search_queries` | string | No | Optional comma-separated list of search queries to execute |
| `processor` | string | No | Processing method: base or pro \(default: base\) |
| `max_results` | number | No | Maximum number of results to return \(default: 5\) |
| `max_chars_per_result` | number | No | Maximum characters per result \(default: 1500\) |
| `apiKey` | string | Yes | Parallel AI API Key |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `results` | array | Search results with excerpts from relevant pages |
## Notes
- Category: `tools`
- Type: `parallel_ai`

View File

@@ -0,0 +1,188 @@
---
title: PostgreSQL
description: Connect to PostgreSQL database
---
import { BlockInfoCard } from "@/components/ui/block-info-card"
<BlockInfoCard
type="postgresql"
color="#336791"
icon={true}
iconSvg={`<svg className="block-icon"
viewBox='-4 0 264 264'
xmlns='http://www.w3.org/2000/svg'
preserveAspectRatio='xMinYMin meet'
>
<path d='M255.008 158.086c-1.535-4.649-5.556-7.887-10.756-8.664-2.452-.366-5.26-.21-8.583.475-5.792 1.195-10.089 1.65-13.225 1.738 11.837-19.985 21.462-42.775 27.003-64.228 8.96-34.689 4.172-50.492-1.423-57.64C233.217 10.847 211.614.683 185.552.372c-13.903-.17-26.108 2.575-32.475 4.549-5.928-1.046-12.302-1.63-18.99-1.738-12.537-.2-23.614 2.533-33.079 8.15-5.24-1.772-13.65-4.27-23.362-5.864-22.842-3.75-41.252-.828-54.718 8.685C6.622 25.672-.937 45.684.461 73.634c.444 8.874 5.408 35.874 13.224 61.48 4.492 14.718 9.282 26.94 14.237 36.33 7.027 13.315 14.546 21.156 22.987 23.972 4.731 1.576 13.327 2.68 22.368-4.85 1.146 1.388 2.675 2.767 4.704 4.048 2.577 1.625 5.728 2.953 8.875 3.74 11.341 2.835 21.964 2.126 31.027-1.848.056 1.612.099 3.152.135 4.482.06 2.157.12 4.272.199 6.25.537 13.374 1.447 23.773 4.143 31.049.148.4.347 1.01.557 1.657 1.345 4.118 3.594 11.012 9.316 16.411 5.925 5.593 13.092 7.308 19.656 7.308 3.292 0 6.433-.432 9.188-1.022 9.82-2.105 20.973-5.311 29.041-16.799 7.628-10.86 11.336-27.217 12.007-52.99.087-.729.167-1.425.244-2.088l.16-1.362 1.797.158.463.031c10.002.456 22.232-1.665 29.743-5.154 5.935-2.754 24.954-12.795 20.476-26.351' />
<path
d='M237.906 160.722c-29.74 6.135-31.785-3.934-31.785-3.934 31.4-46.593 44.527-105.736 33.2-120.211-30.904-39.485-84.399-20.811-85.292-20.327l-.287.052c-5.876-1.22-12.451-1.946-19.842-2.067-13.456-.22-23.664 3.528-31.41 9.402 0 0-95.43-39.314-90.991 49.444.944 18.882 27.064 142.873 58.218 105.422 11.387-13.695 22.39-25.274 22.39-25.274 5.464 3.63 12.006 5.482 18.864 4.817l.533-.452c-.166 1.7-.09 3.363.213 5.332-8.026 8.967-5.667 10.541-21.711 13.844-16.235 3.346-6.698 9.302-.471 10.86 7.549 1.887 25.013 4.561 36.813-11.958l-.47 1.885c3.144 2.519 5.352 16.383 4.982 28.952-.37 12.568-.617 21.197 1.86 27.937 2.479 6.74 4.948 21.905 26.04 17.386 17.623-3.777 26.756-13.564 28.027-29.89.901-11.606 2.942-9.89 3.07-20.267l1.637-4.912c1.887-15.733.3-20.809 11.157-18.448l2.64.232c7.99.363 18.45-1.286 24.589-4.139 13.218-6.134 21.058-16.377 8.024-13.686h.002'
fill='#336791'
/>
<path
d='M108.076 81.525c-2.68-.373-5.107-.028-6.335.902-.69.523-.904 1.129-.962 1.546-.154 1.105.62 2.327 1.096 2.957 1.346 1.784 3.312 3.01 5.258 3.28.282.04.563.058.842.058 3.245 0 6.196-2.527 6.456-4.392.325-2.336-3.066-3.893-6.355-4.35M196.86 81.599c-.256-1.831-3.514-2.353-6.606-1.923-3.088.43-6.082 1.824-5.832 3.659.2 1.427 2.777 3.863 5.827 3.863.258 0 .518-.017.78-.054 2.036-.282 3.53-1.575 4.24-2.32 1.08-1.136 1.706-2.402 1.591-3.225'
fill='#FFF'
/>
<path
d='M247.802 160.025c-1.134-3.429-4.784-4.532-10.848-3.28-18.005 3.716-24.453 1.142-26.57-.417 13.995-21.32 25.508-47.092 31.719-71.137 2.942-11.39 4.567-21.968 4.7-30.59.147-9.463-1.465-16.417-4.789-20.665-13.402-17.125-33.072-26.311-56.882-26.563-16.369-.184-30.199 4.005-32.88 5.183-5.646-1.404-11.801-2.266-18.502-2.376-12.288-.199-22.91 2.743-31.704 8.74-3.82-1.422-13.692-4.811-25.765-6.756-20.872-3.36-37.458-.814-49.294 7.571-14.123 10.006-20.643 27.892-19.38 53.16.425 8.501 5.269 34.653 12.913 59.698 10.062 32.964 21 51.625 32.508 55.464 1.347.449 2.9.763 4.613.763 4.198 0 9.345-1.892 14.7-8.33a529.832 529.832 0 0 1 20.261-22.926c4.524 2.428 9.494 3.784 14.577 3.92.01.133.023.266.035.398a117.66 117.66 0 0 0-2.57 3.175c-3.522 4.471-4.255 5.402-15.592 7.736-3.225.666-11.79 2.431-11.916 8.435-.136 6.56 10.125 9.315 11.294 9.607 4.074 1.02 7.999 1.523 11.742 1.523 9.103 0 17.114-2.992 23.516-8.781-.197 23.386.778 46.43 3.586 53.451 2.3 5.748 7.918 19.795 25.664 19.794 2.604 0 5.47-.303 8.623-.979 18.521-3.97 26.564-12.156 29.675-30.203 1.665-9.645 4.522-32.676 5.866-45.03 2.836.885 6.487 1.29 10.434 1.289 8.232 0 17.731-1.749 23.688-4.514 6.692-3.108 18.768-10.734 16.578-17.36zm-44.106-83.48c-.061 3.647-.563 6.958-1.095 10.414-.573 3.717-1.165 7.56-1.314 12.225-.147 4.54.42 9.26.968 13.825 1.108 9.22 2.245 18.712-2.156 28.078a36.508 36.508 0 0 1-1.95-4.009c-.547-1.326-1.735-3.456-3.38-6.404-6.399-11.476-21.384-38.35-13.713-49.316 2.285-3.264 8.084-6.62 22.64-4.813zm-17.644-61.787c21.334.471 38.21 8.452 50.158 23.72 9.164 11.711-.927 64.998-30.14 110.969a171.33 171.33 0 0 0-.886-1.117l-.37-.462c7.549-12.467 6.073-24.802 4.759-35.738-.54-4.488-1.05-8.727-.92-12.709.134-4.22.692-7.84 1.232-11.34.663-4.313 1.338-8.776 1.152-14.037.139-.552.195-1.204.122-1.978-.475-5.045-6.235-20.144-17.975-33.81-6.422-7.475-15.787-15.84-28.574-21.482 5.5-1.14 13.021-2.203 21.442-2.016zM66.674 175.778c-5.9 7.094-9.974 5.734-11.314 5.288-8.73-2.912-18.86-21.364-27.791-50.624-7.728-25.318-12.244-50.777-12.602-57.916-1.128-22.578 4.345-38.313 16.268-46.769 19.404-13.76 51.306-5.524 64.125-1.347-.184.182-.376.352-.558.537-21.036 21.244-20.537 57.54-20.485 59.759-.002.856.07 2.068.168 3.735.362 6.105 1.036 17.467-.764 30.334-1.672 11.957 2.014 23.66 10.111 32.109a36.275 36.275 0 0 0 2.617 2.468c-3.604 3.86-11.437 12.396-19.775 22.426zm22.479-29.993c-6.526-6.81-9.49-16.282-8.133-25.99 1.9-13.592 1.199-25.43.822-31.79-.053-.89-.1-1.67-.127-2.285 3.073-2.725 17.314-10.355 27.47-8.028 4.634 1.061 7.458 4.217 8.632 9.645 6.076 28.103.804 39.816-3.432 49.229-.873 1.939-1.698 3.772-2.402 5.668l-.546 1.466c-1.382 3.706-2.668 7.152-3.465 10.424-6.938-.02-13.687-2.984-18.819-8.34zm1.065 37.9c-2.026-.506-3.848-1.385-4.917-2.114.893-.42 2.482-.992 5.238-1.56 13.337-2.745 15.397-4.683 19.895-10.394 1.031-1.31 2.2-2.794 3.819-4.602l.002-.002c2.411-2.7 3.514-2.242 5.514-1.412 1.621.67 3.2 2.702 3.84 4.938.303 1.056.643 3.06-.47 4.62-9.396 13.156-23.088 12.987-32.921 10.526zm69.799 64.952c-16.316 3.496-22.093-4.829-25.9-14.346-2.457-6.144-3.665-33.85-2.808-64.447.011-.407-.047-.8-.159-1.17a15.444 15.444 0 0 0-.456-2.162c-1.274-4.452-4.379-8.176-8.104-9.72-1.48-.613-4.196-1.738-7.46-.903.696-2.868 1.903-6.107 3.212-9.614l.549-1.475c.618-1.663 1.394-3.386 2.214-5.21 4.433-9.848 10.504-23.337 3.915-53.81-2.468-11.414-10.71-16.988-23.204-15.693-7.49.775-14.343 3.797-17.761 5.53-.735.372-1.407.732-2.035 1.082.954-11.5 4.558-32.992 18.04-46.59 8.489-8.56 19.794-12.788 33.568-12.56 27.14.444 44.544 14.372 54.366 25.979 8.464 10.001 13.047 
20.076 14.876 25.51-13.755-1.399-23.11 1.316-27.852 8.096-10.317 14.748 5.644 43.372 13.315 57.129 1.407 2.521 2.621 4.7 3.003 5.626 2.498 6.054 5.732 10.096 8.093 13.046.724.904 1.426 1.781 1.96 2.547-4.166 1.201-11.649 3.976-10.967 17.847-.55 6.96-4.461 39.546-6.448 51.059-2.623 15.21-8.22 20.875-23.957 24.25zm68.104-77.936c-4.26 1.977-11.389 3.46-18.161 3.779-7.48.35-11.288-.838-12.184-1.569-.42-8.644 2.797-9.547 6.202-10.503.535-.15 1.057-.297 1.561-.473.313.255.656.508 1.032.756 6.012 3.968 16.735 4.396 31.874 1.271l.166-.033c-2.042 1.909-5.536 4.471-10.49 6.772z'
fill='#FFF'
/>
</svg>`}
/>
{/* MANUAL-CONTENT-START:intro */}
The [PostgreSQL](https://www.postgresql.org/) tool enables you to connect to any PostgreSQL database and perform a wide range of database operations directly within your agentic workflows. With secure connection handling and flexible configuration, you can easily manage and interact with your data.
With the PostgreSQL tool, you can:
- **Query data**: Execute SELECT queries to retrieve data from your PostgreSQL tables using the `postgresql_query` operation.
- **Insert records**: Add new rows to your tables with the `postgresql_insert` operation by specifying the table and data to insert.
- **Update records**: Modify existing data in your tables using the `postgresql_update` operation, providing the table, new data, and WHERE conditions.
- **Delete records**: Remove rows from your tables with the `postgresql_delete` operation, specifying the table and WHERE conditions.
- **Execute raw SQL**: Run any custom SQL command using the `postgresql_execute` operation for advanced use cases.
The PostgreSQL tool is ideal for scenarios where your agents need to interact with structured data—such as automating reporting, syncing data between systems, or powering data-driven workflows. It streamlines database access, making it easy to read, write, and manage your PostgreSQL data programmatically.
{/* MANUAL-CONTENT-END */}
## Usage Instructions
Connect to any PostgreSQL database to execute queries, manage data, and perform database operations. Supports SELECT, INSERT, UPDATE, DELETE operations with secure connection handling.
## Tools
### `postgresql_query`
Execute a SELECT query on PostgreSQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | PostgreSQL server hostname or IP address |
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `query` | string | Yes | SQL SELECT query to execute |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of rows returned from the query |
| `rowCount` | number | Number of rows returned |
### `postgresql_insert`
Insert data into PostgreSQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | PostgreSQL server hostname or IP address |
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to insert data into |
| `data` | object | Yes | Data object to insert \(key-value pairs\) |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Inserted data \(if RETURNING clause used\) |
| `rowCount` | number | Number of rows inserted |
### `postgresql_update`
Update data in PostgreSQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | PostgreSQL server hostname or IP address |
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to update data in |
| `data` | object | Yes | Data object with fields to update \(key-value pairs\) |
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Updated data \(if RETURNING clause used\) |
| `rowCount` | number | Number of rows updated |
### `postgresql_delete`
Delete data from PostgreSQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | PostgreSQL server hostname or IP address |
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `table` | string | Yes | Table name to delete data from |
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Deleted data \(if RETURNING clause used\) |
| `rowCount` | number | Number of rows deleted |
### `postgresql_execute`
Execute raw SQL query on PostgreSQL database
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | PostgreSQL server hostname or IP address |
| `port` | number | Yes | PostgreSQL server port (default: 5432) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | Yes | Database username |
| `password` | string | Yes | Database password |
| `ssl` | string | No | SSL connection mode (disabled, required, preferred) |
| `query` | string | Yes | Raw SQL query to execute |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `rows` | array | Array of rows returned from the query |
| `rowCount` | number | Number of rows affected |
## Notes
- Category: `tools`
- Type: `postgresql`
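
For illustration, a minimal sketch of invoking one of these tools from TypeScript; the `executeTool` helper and its import path are assumptions, but the input and output fields mirror the `postgresql_insert` tables above.

```ts
// `executeTool` is a hypothetical helper; field names follow the spec above.
import { executeTool } from '@/tools'

const result = await executeTool('postgresql_insert', {
  host: 'db.example.com',
  port: 5432,
  database: 'app',
  username: 'app_user',
  password: process.env.PG_PASSWORD ?? '',
  ssl: 'required',
  table: 'users',
  data: { name: 'Ada', email: 'ada@example.com' },
})
console.log(result.message, result.rowCount) // e.g. "Rows inserted" / 1
```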

View File

@@ -1,6 +1,9 @@
# Database (Required)
DATABASE_URL="postgresql://postgres:password@localhost:5432/postgres"
# PostgreSQL Port (Optional) - defaults to 5432 if not specified
# POSTGRES_PORT=5432
# Authentication (Required)
BETTER_AUTH_SECRET=your_secret_key # Use `openssl rand -hex 32` to generate, or visit https://www.better-auth.com/docs/installation
BETTER_AUTH_URL=http://localhost:3000
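
For reference, a minimal sketch of how an application could honor the optional `POSTGRES_PORT` override when building its connection string; the helper below is hypothetical, not the project's actual wiring.

```ts
// Hypothetical helper: substitute POSTGRES_PORT into DATABASE_URL when set.
function resolveDatabaseUrl(): string {
  const base =
    process.env.DATABASE_URL ?? 'postgresql://postgres:password@localhost:5432/postgres'
  const port = process.env.POSTGRES_PORT
  if (!port) return base // fall back to the port already in DATABASE_URL (5432)
  const url = new URL(base)
  url.port = port
  return url.toString()
}
```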

View File

@@ -49,15 +49,12 @@ const PASSWORD_VALIDATIONS = {
},
}
// Validate callback URL to prevent open redirect vulnerabilities
const validateCallbackUrl = (url: string): boolean => {
try {
// If it's a relative URL, it's safe
if (url.startsWith('/')) {
return true
}
// If absolute URL, check if it belongs to the same origin
const currentOrigin = typeof window !== 'undefined' ? window.location.origin : ''
if (url.startsWith(currentOrigin)) {
return true
@@ -70,7 +67,6 @@ const validateCallbackUrl = (url: string): boolean => {
}
}
// Validate password and return array of error messages
const validatePassword = (passwordValue: string): string[] => {
const errors: string[] = []
@@ -475,6 +471,23 @@ export default function LoginPage({
Sign up
</Link>
</div>
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
By signing in, you agree to our{' '}
<Link
href='/terms'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Terms of Service
</Link>{' '}
and{' '}
<Link
href='/privacy'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Privacy Policy
</Link>
</div>
</div>
<Dialog open={forgotPasswordOpen} onOpenChange={setForgotPasswordOpen}>
@@ -504,9 +517,7 @@ export default function LoginPage({
</div>
{resetStatus.type && (
<div
className={`text-sm ${
resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'
}`}
className={`text-sm ${resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'}`}
>
{resetStatus.message}
</div>
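
To make the `validateCallbackUrl` open-redirect check above concrete, a few illustrative inputs (values hypothetical):

```ts
validateCallbackUrl('/workspace') // true – relative URLs are always safe
validateCallbackUrl(`${window.location.origin}/w/1`) // true – same-origin absolute URL
validateCallbackUrl('https://evil.example.com/') // false – foreign origins are rejected
```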

View File

@@ -5,7 +5,7 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useRouter, useSearchParams } from 'next/navigation'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import SignupPage from '@/app/(auth)/signup/signup-form'
vi.mock('next/navigation', () => ({
@@ -22,6 +22,7 @@ vi.mock('@/lib/auth-client', () => ({
sendVerificationOtp: vi.fn(),
},
},
useSession: vi.fn(),
}))
vi.mock('@/app/(auth)/components/social-login-buttons', () => ({
@@ -43,6 +44,9 @@ describe('SignupPage', () => {
vi.clearAllMocks()
;(useRouter as any).mockReturnValue(mockRouter)
;(useSearchParams as any).mockReturnValue(mockSearchParams)
;(useSession as any).mockReturnValue({
refetch: vi.fn().mockResolvedValue({}),
})
mockSearchParams.get.mockReturnValue(null)
})

View File

@@ -7,7 +7,7 @@ import { useRouter, useSearchParams } from 'next/navigation'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import { quickValidateEmail } from '@/lib/email/validation'
import { createLogger } from '@/lib/logs/console/logger'
import { cn } from '@/lib/utils'
@@ -82,6 +82,7 @@ function SignupFormContent({
}) {
const router = useRouter()
const searchParams = useSearchParams()
const { refetch: refetchSession } = useSession()
const [isLoading, setIsLoading] = useState(false)
const [, setMounted] = useState(false)
const [showPassword, setShowPassword] = useState(false)
@@ -330,6 +331,15 @@ function SignupFormContent({
return
}
// Refresh session to get the new user data immediately after signup
try {
await refetchSession()
logger.info('Session refreshed after successful signup')
} catch (sessionError) {
logger.error('Failed to refresh session after signup:', sessionError)
// Continue anyway - the verification flow will handle this
}
// For new signups, always require verification
if (typeof window !== 'undefined') {
sessionStorage.setItem('verificationEmail', emailValue)
@@ -507,6 +517,23 @@ function SignupFormContent({
Sign in
</Link>
</div>
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
By creating an account, you agree to our{' '}
<Link
href='/terms'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Terms of Service
</Link>{' '}
and{' '}
<Link
href='/privacy'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Privacy Policy
</Link>
</div>
</div>
</div>
)

View File

@@ -2,7 +2,7 @@
import { useEffect, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import { env, isTruthy } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
@@ -34,6 +34,7 @@ export function useVerification({
}: UseVerificationParams): UseVerificationReturn {
const router = useRouter()
const searchParams = useSearchParams()
const { refetch: refetchSession } = useSession()
const [otp, setOtp] = useState('')
const [email, setEmail] = useState('')
const [isLoading, setIsLoading] = useState(false)
@@ -136,16 +137,15 @@ export function useVerification({
}
}
// Redirect to proper page after a short delay
setTimeout(() => {
if (isInviteFlow && redirectUrl) {
// For invitation flow, redirect to the invitation page
router.push(redirectUrl)
window.location.href = redirectUrl
} else {
// Default redirect to dashboard
router.push('/workspace')
window.location.href = '/workspace'
}
}, 2000)
}, 1000)
} else {
logger.info('Setting invalid OTP state - API error response')
const message = 'Invalid verification code. Please check and try again.'
@@ -233,7 +233,7 @@ export function useVerification({
'requiresEmailVerification=; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT'
const timeoutId = setTimeout(() => {
router.push('/workspace')
window.location.href = '/workspace'
}, 1000)
return () => clearTimeout(timeoutId)
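
The switch from `router.push` to `window.location.href` presumably trades a client-side transition for a hard navigation so the freshly verified session is re-read everywhere; in sketch form:

```ts
// Client-side transition: fast, but can render with stale session state.
router.push('/workspace')
// Hard navigation: full reload, so middleware and server components
// re-evaluate the session cookie before anything renders.
window.location.href = '/workspace'
```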

View File

@@ -0,0 +1,7 @@
import { toNextJsHandler } from 'better-auth/next-js'
import { auth } from '@/lib/auth'
export const dynamic = 'force-dynamic'
// Handle Stripe webhooks through better-auth
export const { GET, POST } = toNextJsHandler(auth.handler)

View File

@@ -1,109 +0,0 @@
import { type NextRequest, NextResponse } from 'next/server'
import { verifyCronAuth } from '@/lib/auth/internal'
import { processDailyBillingCheck } from '@/lib/billing/core/billing'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('DailyBillingCron')
/**
* Daily billing CRON job endpoint that checks individual billing periods
*/
export async function POST(request: NextRequest) {
try {
const authError = verifyCronAuth(request, 'daily billing check')
if (authError) {
return authError
}
logger.info('Starting daily billing check cron job')
const startTime = Date.now()
// Process overage billing for users and organizations with periods ending today
const result = await processDailyBillingCheck()
const duration = Date.now() - startTime
if (result.success) {
logger.info('Daily billing check completed successfully', {
processedUsers: result.processedUsers,
processedOrganizations: result.processedOrganizations,
totalChargedAmount: result.totalChargedAmount,
duration: `${duration}ms`,
})
return NextResponse.json({
success: true,
summary: {
processedUsers: result.processedUsers,
processedOrganizations: result.processedOrganizations,
totalChargedAmount: result.totalChargedAmount,
duration: `${duration}ms`,
},
})
}
logger.error('Daily billing check completed with errors', {
processedUsers: result.processedUsers,
processedOrganizations: result.processedOrganizations,
totalChargedAmount: result.totalChargedAmount,
errorCount: result.errors.length,
errors: result.errors,
duration: `${duration}ms`,
})
return NextResponse.json(
{
success: false,
summary: {
processedUsers: result.processedUsers,
processedOrganizations: result.processedOrganizations,
totalChargedAmount: result.totalChargedAmount,
errorCount: result.errors.length,
duration: `${duration}ms`,
},
errors: result.errors,
},
{ status: 500 }
)
} catch (error) {
logger.error('Fatal error in monthly billing cron job', { error })
return NextResponse.json(
{
success: false,
error: 'Internal server error during daily billing check',
details: error instanceof Error ? error.message : 'Unknown error',
},
{ status: 500 }
)
}
}
/**
* GET endpoint for manual testing and health checks
*/
export async function GET(request: NextRequest) {
try {
const authError = verifyCronAuth(request, 'daily billing check health check')
if (authError) {
return authError
}
return NextResponse.json({
status: 'ready',
message:
'Daily billing check cron job is ready to process users and organizations with periods ending today',
currentDate: new Date().toISOString().split('T')[0],
})
} catch (error) {
logger.error('Error in billing health check', { error })
return NextResponse.json(
{
status: 'error',
error: error instanceof Error ? error.message : 'Unknown error',
},
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,77 @@
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { subscription as subscriptionTable, user } from '@/db/schema'
const logger = createLogger('BillingPortal')
export async function POST(request: NextRequest) {
const session = await getSession()
try {
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const body = await request.json().catch(() => ({}))
const context: 'user' | 'organization' =
body?.context === 'organization' ? 'organization' : 'user'
const organizationId: string | undefined = body?.organizationId || undefined
const returnUrl: string =
body?.returnUrl || `${env.NEXT_PUBLIC_APP_URL}/workspace?billing=updated`
const stripe = requireStripeClient()
let stripeCustomerId: string | null = null
if (context === 'organization') {
if (!organizationId) {
return NextResponse.json({ error: 'organizationId is required' }, { status: 400 })
}
const rows = await db
.select({ customer: subscriptionTable.stripeCustomerId })
.from(subscriptionTable)
.where(
and(
eq(subscriptionTable.referenceId, organizationId),
eq(subscriptionTable.status, 'active')
)
)
.limit(1)
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
} else {
const rows = await db
.select({ customer: user.stripeCustomerId })
.from(user)
.where(eq(user.id, session.user.id))
.limit(1)
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
}
if (!stripeCustomerId) {
logger.error('Stripe customer not found for portal session', {
context,
organizationId,
userId: session.user.id,
})
return NextResponse.json({ error: 'Stripe customer not found' }, { status: 404 })
}
const portal = await stripe.billingPortal.sessions.create({
customer: stripeCustomerId,
return_url: returnUrl,
})
return NextResponse.json({ url: portal.url })
} catch (error) {
logger.error('Failed to create billing portal session', { error })
return NextResponse.json({ error: 'Failed to create billing portal session' }, { status: 500 })
}
}
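
A hypothetical client-side call to this new portal endpoint; the route path is assumed, while the body fields and response shape come from the handler above.

```ts
// Route path assumed for illustration.
const res = await fetch('/api/billing/portal', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ context: 'organization', organizationId: 'org_123' }),
})
const { url } = await res.json()
window.location.href = url // hand the user off to Stripe's hosted billing portal
```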

View File

@@ -5,7 +5,7 @@ import { getSimplifiedBillingSummary } from '@/lib/billing/core/billing'
import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { member } from '@/db/schema'
import { member, userStats } from '@/db/schema'
const logger = createLogger('UnifiedBillingAPI')
@@ -45,6 +45,16 @@ export async function GET(request: NextRequest) {
if (context === 'user') {
// Get user billing (may include organization if they're part of one)
billingData = await getSimplifiedBillingSummary(session.user.id, contextId || undefined)
// Attach billingBlocked status for the current user
const stats = await db
.select({ blocked: userStats.billingBlocked })
.from(userStats)
.where(eq(userStats.userId, session.user.id))
.limit(1)
billingData = {
...billingData,
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
}
} else {
// Get user role in organization for permission checks first
const memberRecord = await db
@@ -78,8 +88,10 @@ export async function GET(request: NextRequest) {
subscriptionStatus: rawBillingData.subscriptionStatus,
totalSeats: rawBillingData.totalSeats,
usedSeats: rawBillingData.usedSeats,
seatsCount: rawBillingData.seatsCount,
totalCurrentUsage: rawBillingData.totalCurrentUsage,
totalUsageLimit: rawBillingData.totalUsageLimit,
minimumBillingAmount: rawBillingData.minimumBillingAmount,
averageUsagePerMember: rawBillingData.averageUsagePerMember,
billingPeriodStart: rawBillingData.billingPeriodStart?.toISOString() || null,
billingPeriodEnd: rawBillingData.billingPeriodEnd?.toISOString() || null,
@@ -92,11 +104,25 @@ export async function GET(request: NextRequest) {
const userRole = memberRecord[0].role
// Include the requesting user's blocked flag as well so UI can reflect it
const stats = await db
.select({ blocked: userStats.billingBlocked })
.from(userStats)
.where(eq(userStats.userId, session.user.id))
.limit(1)
// Merge blocked flag into data for convenience
billingData = {
...billingData,
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
}
return NextResponse.json({
success: true,
context,
data: billingData,
userRole,
billingBlocked: billingData.billingBlocked,
})
}
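
Based on this diff, the unified billing response now carries a `billingBlocked` flag; a sketch of the resulting shape (field names inferred from the code above):

```ts
interface UnifiedBillingResponse {
  success: boolean
  context: 'user' | 'organization'
  data: {
    billingBlocked: boolean // merged in from userStats.billingBlocked
    // ...existing billing summary fields
  }
  userRole?: string // organization context only
  billingBlocked?: boolean // mirrored at the top level for organization responses
}
```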

View File

@@ -115,52 +115,34 @@ export async function POST(req: NextRequest) {
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
if (userStatsRecords.length === 0) {
// Create new user stats record (same logic as ExecutionLogger)
await db.insert(userStats).values({
id: crypto.randomUUID(),
userId: userId,
totalManualExecutions: 0,
totalApiCalls: 0,
totalWebhookTriggers: 0,
totalScheduledExecutions: 0,
totalChatExecutions: 0,
totalTokensUsed: totalTokens,
totalCost: costToStore.toString(),
currentPeriodCost: costToStore.toString(),
// Copilot usage tracking
totalCopilotCost: costToStore.toString(),
totalCopilotTokens: totalTokens,
totalCopilotCalls: 1,
lastActive: new Date(),
})
logger.info(`[${requestId}] Created new user stats record`, {
userId,
totalCost: costToStore,
totalTokens,
})
} else {
// Update existing user stats record (same logic as ExecutionLogger)
const updateFields = {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
// Copilot usage tracking increments
totalCopilotCost: sql`total_copilot_cost + ${costToStore}`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalApiCalls: sql`total_api_calls`,
lastActive: new Date(),
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.info(`[${requestId}] Updated user stats record`, {
userId,
addedCost: costToStore,
addedTokens: totalTokens,
})
logger.error(
`[${requestId}] User stats record not found - should be created during onboarding`,
{
userId,
}
)
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}
// Update existing user stats record (same logic as ExecutionLogger)
const updateFields = {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
// Copilot usage tracking increments
totalCopilotCost: sql`total_copilot_cost + ${costToStore}`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalApiCalls: sql`total_api_calls`,
lastActive: new Date(),
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.info(`[${requestId}] Updated user stats record`, {
userId,
addedCost: costToStore,
addedTokens: totalTokens,
})
const duration = Date.now() - startTime

View File

@@ -1,116 +0,0 @@
import { headers } from 'next/headers'
import { type NextRequest, NextResponse } from 'next/server'
import type Stripe from 'stripe'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { handleInvoiceWebhook } from '@/lib/billing/webhooks/stripe-invoice-webhooks'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('StripeInvoiceWebhook')
/**
* Stripe billing webhook endpoint for invoice-related events
* Endpoint: /api/billing/webhooks/stripe
* Handles: invoice.payment_succeeded, invoice.payment_failed, invoice.finalized
*/
export async function POST(request: NextRequest) {
try {
const body = await request.text()
const headersList = await headers()
const signature = headersList.get('stripe-signature')
if (!signature) {
logger.error('Missing Stripe signature header')
return NextResponse.json({ error: 'Missing Stripe signature' }, { status: 400 })
}
if (!env.STRIPE_BILLING_WEBHOOK_SECRET) {
logger.error('Missing Stripe webhook secret configuration')
return NextResponse.json({ error: 'Webhook secret not configured' }, { status: 500 })
}
// Check if Stripe client is available
let stripe
try {
stripe = requireStripeClient()
} catch (stripeError) {
logger.error('Stripe client not available for webhook processing', {
error: stripeError,
})
return NextResponse.json({ error: 'Stripe client not configured' }, { status: 500 })
}
// Verify webhook signature
let event: Stripe.Event
try {
event = stripe.webhooks.constructEvent(body, signature, env.STRIPE_BILLING_WEBHOOK_SECRET)
} catch (signatureError) {
logger.error('Invalid Stripe webhook signature', {
error: signatureError,
signature,
})
return NextResponse.json({ error: 'Invalid signature' }, { status: 400 })
}
logger.info('Received Stripe invoice webhook', {
eventId: event.id,
eventType: event.type,
})
// Handle specific invoice events
const supportedEvents = [
'invoice.payment_succeeded',
'invoice.payment_failed',
'invoice.finalized',
]
if (supportedEvents.includes(event.type)) {
try {
await handleInvoiceWebhook(event)
logger.info('Successfully processed invoice webhook', {
eventId: event.id,
eventType: event.type,
})
return NextResponse.json({ received: true })
} catch (processingError) {
logger.error('Failed to process invoice webhook', {
eventId: event.id,
eventType: event.type,
error: processingError,
})
// Return 500 to tell Stripe to retry the webhook
return NextResponse.json({ error: 'Failed to process webhook' }, { status: 500 })
}
} else {
// Not a supported invoice event, ignore
logger.info('Ignoring unsupported webhook event', {
eventId: event.id,
eventType: event.type,
supportedEvents,
})
return NextResponse.json({ received: true })
}
} catch (error) {
logger.error('Fatal error in invoice webhook handler', {
error,
url: request.url,
})
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}
/**
* GET endpoint for webhook health checks
*/
export async function GET() {
return NextResponse.json({
status: 'healthy',
webhook: 'stripe-invoices',
events: ['invoice.payment_succeeded', 'invoice.payment_failed', 'invoice.finalized'],
})
}

View File

@@ -45,6 +45,7 @@ export async function GET(request: Request) {
'support',
'admin',
'qa',
'agent',
]
if (reservedSubdomains.includes(subdomain)) {
return NextResponse.json(

View File

@@ -1,34 +1,12 @@
import { createCipheriv, createHash, createHmac, randomBytes } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { generateApiKey } from '@/lib/utils'
import { db } from '@/db'
import { copilotApiKeys } from '@/db/schema'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const logger = createLogger('CopilotApiKeysGenerate')
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function encryptRandomIv(plaintext: string, keyString: string): string {
const key = deriveKey(keyString)
const iv = randomBytes(16)
const cipher = createCipheriv('aes-256-gcm', key, iv)
let encrypted = cipher.update(plaintext, 'utf8', 'hex')
encrypted += cipher.final('hex')
const authTag = cipher.getAuthTag().toString('hex')
return `${iv.toString('hex')}:${encrypted}:${authTag}`
}
function computeLookup(plaintext: string, keyString: string): string {
// Deterministic, constant-time comparable MAC: HMAC-SHA256(DB_KEY, plaintext)
return createHmac('sha256', Buffer.from(keyString, 'utf8'))
.update(plaintext, 'utf8')
.digest('hex')
}
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
export async function POST(req: NextRequest) {
try {
@@ -37,34 +15,39 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
}
const userId = session.user.id
// Generate and prefix the key (strip the generic sim_ prefix from the random part)
const rawKey = generateApiKey().replace(/^sim_/, '')
const plaintextKey = `sk-sim-copilot-${rawKey}`
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/generate`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId }),
})
// Encrypt with random IV for confidentiality
const dbEncrypted = encryptRandomIv(plaintextKey, env.AGENT_API_DB_ENCRYPTION_KEY)
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent generate key error', { status: res.status, error: errorBody })
return NextResponse.json(
{ error: 'Failed to generate copilot API key' },
{ status: res.status || 500 }
)
}
// Compute deterministic lookup value for O(1) search
const lookup = computeLookup(plaintextKey, env.AGENT_API_DB_ENCRYPTION_KEY)
const data = (await res.json().catch(() => null)) as { apiKey?: string } | null
const [inserted] = await db
.insert(copilotApiKeys)
.values({ userId, apiKeyEncrypted: dbEncrypted, apiKeyLookup: lookup })
.returning({ id: copilotApiKeys.id })
if (!data?.apiKey) {
logger.error('Sim Agent generate key returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
return NextResponse.json(
{ success: true, key: { id: inserted.id, apiKey: plaintextKey } },
{ success: true, key: { id: 'new', apiKey: data.apiKey } },
{ status: 201 }
)
} catch (error) {
logger.error('Failed to generate copilot API key', { error })
logger.error('Failed to proxy generate copilot API key', { error })
return NextResponse.json({ error: 'Failed to generate copilot API key' }, { status: 500 })
}
}

View File

@@ -1,32 +1,12 @@
import { createDecipheriv, createHash } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotApiKeys } from '@/db/schema'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const logger = createLogger('CopilotApiKeys')
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function decryptWithKey(encryptedValue: string, keyString: string): string {
const parts = encryptedValue.split(':')
if (parts.length !== 3) {
throw new Error('Invalid encrypted value format')
}
const [ivHex, encryptedHex, authTagHex] = parts
const key = deriveKey(keyString)
const iv = Buffer.from(ivHex, 'hex')
const decipher = createDecipheriv('aes-256-gcm', key, iv)
decipher.setAuthTag(Buffer.from(authTagHex, 'hex'))
let decrypted = decipher.update(encryptedHex, 'hex', 'utf8')
decrypted += decipher.final('utf8')
return decrypted
}
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
export async function GET(request: NextRequest) {
try {
@@ -35,22 +15,31 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
}
const userId = session.user.id
const rows = await db
.select({ id: copilotApiKeys.id, apiKeyEncrypted: copilotApiKeys.apiKeyEncrypted })
.from(copilotApiKeys)
.where(eq(copilotApiKeys.userId, userId))
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/get-api-keys`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId }),
})
const keys = rows.map((row) => ({
id: row.id,
apiKey: decryptWithKey(row.apiKeyEncrypted, env.AGENT_API_DB_ENCRYPTION_KEY as string),
}))
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent get-api-keys error', { status: res.status, error: errorBody })
return NextResponse.json({ error: 'Failed to get keys' }, { status: res.status || 500 })
}
const apiKeys = (await res.json().catch(() => null)) as { id: string; apiKey: string }[] | null
if (!Array.isArray(apiKeys)) {
logger.error('Sim Agent get-api-keys returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
const keys = apiKeys
return NextResponse.json({ keys }, { status: 200 })
} catch (error) {
@@ -73,9 +62,26 @@ export async function DELETE(request: NextRequest) {
return NextResponse.json({ error: 'id is required' }, { status: 400 })
}
await db
.delete(copilotApiKeys)
.where(and(eq(copilotApiKeys.userId, userId), eq(copilotApiKeys.id, id)))
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/delete`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId, apiKeyId: id }),
})
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent delete key error', { status: res.status, error: errorBody })
return NextResponse.json({ error: 'Failed to delete key' }, { status: res.status || 500 })
}
const data = (await res.json().catch(() => null)) as { success?: boolean } | null
if (!data?.success) {
logger.error('Sim Agent delete key returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
return NextResponse.json({ success: true }, { status: 200 })
} catch (error) {

View File

@@ -1,50 +1,29 @@
import { createHmac } from 'crypto'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/env'
import { checkInternalApiKey } from '@/lib/copilot/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotApiKeys, userStats } from '@/db/schema'
import { userStats } from '@/db/schema'
const logger = createLogger('CopilotApiKeysValidate')
function computeLookup(plaintext: string, keyString: string): string {
// Deterministic MAC: HMAC-SHA256(DB_KEY, plaintext)
return createHmac('sha256', Buffer.from(keyString, 'utf8'))
.update(plaintext, 'utf8')
.digest('hex')
}
export async function POST(req: NextRequest) {
try {
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
// Authenticate via internal API key header
const auth = checkInternalApiKey(req)
if (!auth.success) {
return new NextResponse(null, { status: 401 })
}
const body = await req.json().catch(() => null)
const apiKey = typeof body?.apiKey === 'string' ? body.apiKey : undefined
const userId = typeof body?.userId === 'string' ? body.userId : undefined
if (!apiKey) {
return new NextResponse(null, { status: 401 })
if (!userId) {
return NextResponse.json({ error: 'userId is required' }, { status: 400 })
}
const lookup = computeLookup(apiKey, env.AGENT_API_DB_ENCRYPTION_KEY)
logger.info('[API VALIDATION] Validating usage limit', { userId })
// Find matching API key and its user
const rows = await db
.select({ id: copilotApiKeys.id, userId: copilotApiKeys.userId })
.from(copilotApiKeys)
.where(eq(copilotApiKeys.apiKeyLookup, lookup))
.limit(1)
if (rows.length === 0) {
return new NextResponse(null, { status: 401 })
}
const { userId } = rows[0]
// Check usage for the associated user
const usage = await db
.select({
currentPeriodCost: userStats.currentPeriodCost,
@@ -55,6 +34,8 @@ export async function POST(req: NextRequest) {
.where(eq(userStats.userId, userId))
.limit(1)
logger.info('[API VALIDATION] Usage limit validated', { userId, usage })
if (usage.length > 0) {
const currentUsage = Number.parseFloat(
(usage[0].currentPeriodCost?.toString() as string) ||
@@ -64,16 +45,14 @@ export async function POST(req: NextRequest) {
const limit = Number.parseFloat((usage[0].currentUsageLimit as unknown as string) || '0')
if (!Number.isNaN(limit) && limit > 0 && currentUsage >= limit) {
// Usage exceeded
logger.info('[API VALIDATION] Usage exceeded', { userId, currentUsage, limit })
return new NextResponse(null, { status: 402 })
}
}
// Valid and within usage limits
return new NextResponse(null, { status: 200 })
} catch (error) {
logger.error('Error validating copilot API key', { error })
return NextResponse.json({ error: 'Failed to validate key' }, { status: 500 })
logger.error('Error validating usage limit', { error })
return NextResponse.json({ error: 'Failed to validate usage' }, { status: 500 })
}
}
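
As rewritten, the route authenticates with an internal API key and validates the given user's usage against their limit; a hypothetical caller (URL path, header name, and constants are all assumptions):

```ts
const APP_URL = 'http://localhost:3000' // assumption
const INTERNAL_API_KEY = process.env.INTERNAL_API_KEY ?? '' // assumption

const res = await fetch(`${APP_URL}/api/copilot/api-keys/validate`, {
  method: 'POST',
  headers: { 'Content-Type': 'application/json', 'x-api-key': INTERNAL_API_KEY },
  body: JSON.stringify({ userId: 'user-123' }),
})
// 200 → within usage limits, 400 → missing userId,
// 401 → bad internal key, 402 → usage limit exceeded
```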

View File

@@ -224,9 +224,7 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
depth: 0,
origin: 'http://localhost:3000',
}),
})
)
@@ -288,9 +286,7 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
depth: 0,
origin: 'http://localhost:3000',
}),
})
)
@@ -300,7 +296,6 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
const newChat = {
id: 'chat-123',
userId: 'user-123',
@@ -309,8 +304,6 @@ describe('Copilot Chat API Route', () => {
}
mockReturning.mockResolvedValue([newChat])
// Mock sim agent response
;(global.fetch as any).mockResolvedValue({
ok: true,
body: new ReadableStream({
@@ -344,9 +337,7 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
depth: 0,
origin: 'http://localhost:3000',
}),
})
)
@@ -356,11 +347,8 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
// Mock sim agent error
;(global.fetch as any).mockResolvedValue({
ok: false,
status: 500,
@@ -406,11 +394,8 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
// Mock sim agent response
;(global.fetch as any).mockResolvedValue({
ok: true,
body: new ReadableStream({
@@ -440,9 +425,7 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'ask',
provider: 'openai',
depth: 0,
origin: 'http://localhost:3000',
}),
})
)

View File

@@ -1,4 +1,3 @@
import { createCipheriv, createDecipheriv, createHash, randomBytes } from 'crypto'
import { and, desc, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
@@ -11,77 +10,36 @@ import {
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { getCopilotModel } from '@/lib/copilot/config'
import { TITLE_GENERATION_SYSTEM_PROMPT, TITLE_GENERATION_USER_PROMPT } from '@/lib/copilot/prompts'
import type { CopilotProviderConfig } from '@/lib/copilot/types'
import { env } from '@/lib/env'
import { generateChatTitle } from '@/lib/generate-chat-title'
import { createLogger } from '@/lib/logs/console/logger'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
import { downloadFile } from '@/lib/uploads'
import { downloadFromS3WithConfig } from '@/lib/uploads/s3/s3-client'
import { S3_COPILOT_CONFIG, USE_S3_STORAGE } from '@/lib/uploads/setup'
import { createFileContent, isSupportedFileType } from '@/lib/uploads/file-utils'
import { S3_COPILOT_CONFIG } from '@/lib/uploads/setup'
import { downloadFile, getStorageProvider } from '@/lib/uploads/storage-client'
import { db } from '@/db'
import { copilotChats } from '@/db/schema'
import { executeProviderRequest } from '@/providers'
import { createAnthropicFileContent, isSupportedFileType } from './file-utils'
const logger = createLogger('CopilotChatAPI')
// Sim Agent API configuration
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
function getRequestOrigin(_req: NextRequest): string {
try {
// Strictly use configured Better Auth URL
return env.BETTER_AUTH_URL || ''
} catch (_) {
return ''
}
}
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function decryptWithKey(encryptedValue: string, keyString: string): string {
const [ivHex, encryptedHex, authTagHex] = encryptedValue.split(':')
if (!ivHex || !encryptedHex || !authTagHex) {
throw new Error('Invalid encrypted format')
}
const key = deriveKey(keyString)
const iv = Buffer.from(ivHex, 'hex')
const decipher = createDecipheriv('aes-256-gcm', key, iv)
decipher.setAuthTag(Buffer.from(authTagHex, 'hex'))
let decrypted = decipher.update(encryptedHex, 'hex', 'utf8')
decrypted += decipher.final('utf8')
return decrypted
}
function encryptWithKey(plaintext: string, keyString: string): string {
const key = deriveKey(keyString)
const iv = randomBytes(16)
const cipher = createCipheriv('aes-256-gcm', key, iv)
let encrypted = cipher.update(plaintext, 'utf8', 'hex')
encrypted += cipher.final('hex')
const authTag = cipher.getAuthTag().toString('hex')
return `${iv.toString('hex')}:${encrypted}:${authTag}`
}
// Schema for file attachments
const FileAttachmentSchema = z.object({
id: z.string(),
s3_key: z.string(),
key: z.string(),
filename: z.string(),
media_type: z.string(),
size: z.number(),
})
// Schema for chat messages
const ChatMessageSchema = z.object({
message: z.string().min(1, 'Message is required'),
userMessageId: z.string().optional(), // ID from frontend for the user message
chatId: z.string().optional(),
workflowId: z.string().min(1, 'Workflow ID is required'),
mode: z.enum(['ask', 'agent']).optional().default('agent'),
depth: z.number().int().min(-2).max(3).optional().default(0),
depth: z.number().int().min(0).max(3).optional().default(0),
prefetch: z.boolean().optional(),
createNewChat: z.boolean().optional().default(false),
stream: z.boolean().optional().default(true),
@@ -89,90 +47,20 @@ const ChatMessageSchema = z.object({
fileAttachments: z.array(FileAttachmentSchema).optional(),
provider: z.string().optional().default('openai'),
conversationId: z.string().optional(),
})
/**
* Generate a chat title using LLM
*/
async function generateChatTitle(userMessage: string): Promise<string> {
try {
const { provider, model } = getCopilotModel('title')
// Get the appropriate API key for the provider
let apiKey: string | undefined
if (provider === 'anthropic') {
// Use rotating API key for Anthropic
const { getRotatingApiKey } = require('@/lib/utils')
try {
apiKey = getRotatingApiKey('anthropic')
logger.debug(`Using rotating API key for Anthropic title generation`)
} catch (e) {
// If rotation fails, let the provider handle it
logger.warn(`Failed to get rotating API key for Anthropic:`, e)
}
}
const response = await executeProviderRequest(provider, {
model,
systemPrompt: TITLE_GENERATION_SYSTEM_PROMPT,
context: TITLE_GENERATION_USER_PROMPT(userMessage),
temperature: 0.3,
maxTokens: 50,
apiKey: apiKey || '',
stream: false,
})
if (typeof response === 'object' && 'content' in response) {
return response.content?.trim() || 'New Chat'
}
return 'New Chat'
} catch (error) {
logger.error('Failed to generate chat title:', error)
return 'New Chat'
}
}
/**
* Generate chat title asynchronously and update the database
*/
async function generateChatTitleAsync(
chatId: string,
userMessage: string,
requestId: string,
streamController?: ReadableStreamDefaultController<Uint8Array>
): Promise<void> {
try {
logger.info(`[${requestId}] Starting async title generation for chat ${chatId}`)
const title = await generateChatTitle(userMessage)
// Update the chat with the generated title
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
})
.where(eq(copilotChats.id, chatId))
// Send title_updated event to client if streaming
if (streamController) {
const encoder = new TextEncoder()
const titleEvent = `data: ${JSON.stringify({
type: 'title_updated',
title: title,
})}\n\n`
streamController.enqueue(encoder.encode(titleEvent))
logger.debug(`[${requestId}] Sent title_updated event to client: "${title}"`)
}
logger.info(`[${requestId}] Generated title for chat ${chatId}: "${title}"`)
} catch (error) {
logger.error(`[${requestId}] Failed to generate title for chat ${chatId}:`, error)
// Don't throw - this is a background operation
}
}
contexts: z
.array(
z.object({
kind: z.enum(['past_chat', 'workflow', 'blocks', 'logs', 'knowledge', 'templates']),
label: z.string(),
chatId: z.string().optional(),
workflowId: z.string().optional(),
knowledgeId: z.string().optional(),
blockId: z.string().optional(),
templateId: z.string().optional(),
})
)
.optional(),
})
/**
* POST /api/copilot/chat
@@ -206,14 +94,37 @@ export async function POST(req: NextRequest) {
fileAttachments,
provider,
conversationId,
contexts,
} = ChatMessageSchema.parse(body)
// Derive request origin for downstream service
const requestOrigin = getRequestOrigin(req)
if (!requestOrigin) {
logger.error(`[${tracker.requestId}] Missing required configuration: BETTER_AUTH_URL`)
return createInternalServerErrorResponse('Missing required configuration: BETTER_AUTH_URL')
try {
logger.info(`[${tracker.requestId}] Received chat POST`, {
hasContexts: Array.isArray(contexts),
contextsCount: Array.isArray(contexts) ? contexts.length : 0,
contextsPreview: Array.isArray(contexts)
? contexts.map((c: any) => ({
kind: c?.kind,
chatId: c?.chatId,
workflowId: c?.workflowId,
label: c?.label,
}))
: undefined,
})
} catch {}
// Preprocess contexts server-side
let agentContexts: Array<{ type: string; content: string }> = []
if (Array.isArray(contexts) && contexts.length > 0) {
try {
const { processContextsServer } = await import('@/lib/copilot/process-contents')
const processed = await processContextsServer(contexts as any, authenticatedUserId)
agentContexts = processed
logger.info(`[${tracker.requestId}] Contexts processed for request`, {
processedCount: agentContexts.length,
kinds: agentContexts.map((c) => c.type),
lengthPreview: agentContexts.map((c) => c.content?.length ?? 0),
})
} catch (e) {
logger.error(`[${tracker.requestId}] Failed to process contexts`, e)
}
}
// Consolidation mapping: map negative depths to base depth with prefetch=true
@@ -229,22 +140,6 @@ export async function POST(req: NextRequest) {
}
}
logger.info(`[${tracker.requestId}] Processing copilot chat request`, {
userId: authenticatedUserId,
workflowId,
chatId,
mode,
stream,
createNewChat,
messageLength: message.length,
hasImplicitFeedback: !!implicitFeedback,
provider: provider || 'openai',
hasConversationId: !!conversationId,
depth,
prefetch,
origin: requestOrigin,
})
// Handle chat context
let currentChat: any = null
let conversationHistory: any[] = []
@@ -285,8 +180,6 @@ export async function POST(req: NextRequest) {
// Process file attachments if present
const processedFileContents: any[] = []
if (fileAttachments && fileAttachments.length > 0) {
logger.info(`[${tracker.requestId}] Processing ${fileAttachments.length} file attachments`)
for (const attachment of fileAttachments) {
try {
// Check if file type is supported
@@ -295,23 +188,30 @@ export async function POST(req: NextRequest) {
continue
}
// Download file from S3
logger.info(`[${tracker.requestId}] Downloading file: ${attachment.s3_key}`)
const storageProvider = getStorageProvider()
let fileBuffer: Buffer
if (USE_S3_STORAGE) {
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
if (storageProvider === 's3') {
fileBuffer = await downloadFile(attachment.key, {
bucket: S3_COPILOT_CONFIG.bucket,
region: S3_COPILOT_CONFIG.region,
})
} else if (storageProvider === 'blob') {
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
fileBuffer = await downloadFile(attachment.key, {
containerName: BLOB_COPILOT_CONFIG.containerName,
accountName: BLOB_COPILOT_CONFIG.accountName,
accountKey: BLOB_COPILOT_CONFIG.accountKey,
connectionString: BLOB_COPILOT_CONFIG.connectionString,
})
} else {
// Fallback to generic downloadFile for other storage providers
fileBuffer = await downloadFile(attachment.s3_key)
fileBuffer = await downloadFile(attachment.key)
}
// Convert to Anthropic format
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
// Convert to format
const fileContent = createFileContent(fileBuffer, attachment.media_type)
if (fileContent) {
processedFileContents.push(fileContent)
logger.info(
`[${tracker.requestId}] Processed file: ${attachment.filename} (${attachment.media_type})`
)
}
} catch (error) {
logger.error(
@@ -336,14 +236,26 @@ export async function POST(req: NextRequest) {
for (const attachment of msg.fileAttachments) {
try {
if (isSupportedFileType(attachment.media_type)) {
const storageProvider = getStorageProvider()
let fileBuffer: Buffer
if (USE_S3_STORAGE) {
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
if (storageProvider === 's3') {
fileBuffer = await downloadFile(attachment.key, {
bucket: S3_COPILOT_CONFIG.bucket,
region: S3_COPILOT_CONFIG.region,
})
} else if (storageProvider === 'blob') {
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
fileBuffer = await downloadFile(attachment.key, {
containerName: BLOB_COPILOT_CONFIG.containerName,
accountName: BLOB_COPILOT_CONFIG.accountName,
accountKey: BLOB_COPILOT_CONFIG.accountKey,
connectionString: BLOB_COPILOT_CONFIG.connectionString,
})
} else {
// Fallback to generic downloadFile for other storage providers
fileBuffer = await downloadFile(attachment.s3_key)
fileBuffer = await downloadFile(attachment.key)
}
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
const fileContent = createFileContent(fileBuffer, attachment.media_type)
if (fileContent) {
content.push(fileContent)
}
@@ -399,8 +311,31 @@ export async function POST(req: NextRequest) {
})
}
const defaults = getCopilotModel('chat')
const modelToUse = env.COPILOT_MODEL || defaults.model
let providerConfig: CopilotProviderConfig | undefined
const providerEnv = env.COPILOT_PROVIDER as any
if (providerEnv) {
if (providerEnv === 'azure-openai') {
providerConfig = {
provider: 'azure-openai',
model: modelToUse,
apiKey: env.AZURE_OPENAI_API_KEY,
apiVersion: 'preview',
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else {
providerConfig = {
provider: providerEnv,
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
}
}
}
// Determine provider and conversationId to use for this request
const providerToUse = provider || 'openai'
const effectiveConversationId =
(currentChat?.conversationId as string | undefined) || conversationId
@@ -416,35 +351,19 @@ export async function POST(req: NextRequest) {
stream: stream,
streamToolCalls: true,
mode: mode,
provider: providerToUse,
...(providerConfig ? { provider: providerConfig } : {}),
...(effectiveConversationId ? { conversationId: effectiveConversationId } : {}),
...(typeof effectiveDepth === 'number' ? { depth: effectiveDepth } : {}),
...(typeof effectivePrefetch === 'boolean' ? { prefetch: effectivePrefetch } : {}),
...(session?.user?.name && { userName: session.user.name }),
...(requestOrigin ? { origin: requestOrigin } : {}),
...(agentContexts.length > 0 && { context: agentContexts }),
}
// Log the payload being sent to the streaming endpoint
try {
logger.info(`[${tracker.requestId}] Sending payload to sim agent streaming endpoint`, {
url: `${SIM_AGENT_API_URL}/api/chat-completion-streaming`,
provider: providerToUse,
mode,
stream,
workflowId,
hasConversationId: !!effectiveConversationId,
depth: typeof effectiveDepth === 'number' ? effectiveDepth : undefined,
prefetch: typeof effectivePrefetch === 'boolean' ? effectivePrefetch : undefined,
messagesCount: requestPayload.messages.length,
...(requestOrigin ? { origin: requestOrigin } : {}),
logger.info(`[${tracker.requestId}] About to call Sim Agent with context`, {
context: (requestPayload as any).context,
})
// Full payload as JSON string
logger.info(
`[${tracker.requestId}] Full streaming payload: ${JSON.stringify(requestPayload)}`
)
} catch (e) {
logger.warn(`[${tracker.requestId}] Failed to log payload preview for streaming endpoint`, e)
}
} catch {}
const simAgentResponse = await fetch(`${SIM_AGENT_API_URL}/api/chat-completion-streaming`, {
method: 'POST',
@@ -475,8 +394,6 @@ export async function POST(req: NextRequest) {
// If streaming is requested, forward the stream and update chat later
if (stream && simAgentResponse.body) {
logger.info(`[${tracker.requestId}] Streaming response from sim agent`)
// Create user message to save
const userMessage = {
id: userMessageId || crypto.randomUUID(), // Use frontend ID if provided
@@ -484,6 +401,11 @@ export async function POST(req: NextRequest) {
content: message,
timestamp: new Date().toISOString(),
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
...(Array.isArray(contexts) &&
contexts.length > 0 && {
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
}),
}
// Create a pass-through stream that captures the response
@@ -493,7 +415,7 @@ export async function POST(req: NextRequest) {
let assistantContent = ''
const toolCalls: any[] = []
let buffer = ''
let isFirstDone = true
const isFirstDone = true
let responseIdFromStart: string | undefined
let responseIdFromDone: string | undefined
// Track tool call progress to identify a safe done event
@@ -515,30 +437,30 @@ export async function POST(req: NextRequest) {
// Start title generation in parallel if needed
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
logger.info(`[${tracker.requestId}] Starting title generation with stream updates`, {
chatId: actualChatId,
hasTitle: !!currentChat?.title,
conversationLength: conversationHistory.length,
message: message.substring(0, 100) + (message.length > 100 ? '...' : ''),
})
generateChatTitleAsync(actualChatId, message, tracker.requestId, controller).catch(
(error) => {
generateChatTitle(message)
.then(async (title) => {
if (title) {
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
})
.where(eq(copilotChats.id, actualChatId!))
const titleEvent = `data: ${JSON.stringify({
type: 'title_updated',
title: title,
})}\n\n`
controller.enqueue(encoder.encode(titleEvent))
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
}
})
.catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
}
)
})
} else {
logger.debug(`[${tracker.requestId}] Skipping title generation`, {
chatId: actualChatId,
hasTitle: !!currentChat?.title,
conversationLength: conversationHistory.length,
reason: !actualChatId
? 'no chatId'
: currentChat?.title
? 'already has title'
: conversationHistory.length > 0
? 'not first message'
: 'unknown',
})
logger.debug(`[${tracker.requestId}] Skipping title generation`)
}
// Forward the sim agent stream and capture assistant response
@@ -549,7 +471,6 @@ export async function POST(req: NextRequest) {
while (true) {
const { done, value } = await reader.read()
if (done) {
logger.info(`[${tracker.requestId}] Stream reading completed`)
break
}
@@ -559,13 +480,9 @@ export async function POST(req: NextRequest) {
controller.enqueue(value)
} catch (error) {
// Client disconnected - stop reading from sim agent
logger.info(
`[${tracker.requestId}] Client disconnected, stopping stream processing`
)
reader.cancel() // Stop reading from sim agent
break
}
const chunkSize = value.byteLength
// Decode and parse SSE events for logging and capturing content
const decodedChunk = decoder.decode(value, { stream: true })
@@ -601,22 +518,12 @@ export async function POST(req: NextRequest) {
break
case 'reasoning':
// Treat like thinking: do not add to assistantContent to avoid leaking
logger.debug(
`[${tracker.requestId}] Reasoning chunk received (${(event.data || event.content || '').length} chars)`
)
break
case 'tool_call':
logger.info(
`[${tracker.requestId}] Tool call ${event.data?.partial ? '(partial)' : '(complete)'}:`,
{
id: event.data?.id,
name: event.data?.name,
arguments: event.data?.arguments,
blockIndex: event.data?._blockIndex,
}
)
if (!event.data?.partial) {
toolCalls.push(event.data)
if (event.data?.id) {
@@ -625,30 +532,13 @@ export async function POST(req: NextRequest) {
}
break
case 'tool_execution':
logger.info(`[${tracker.requestId}] Tool execution started:`, {
toolCallId: event.toolCallId,
toolName: event.toolName,
status: event.status,
})
case 'tool_generating':
if (event.toolCallId) {
if (event.status === 'completed') {
startedToolExecutionIds.add(event.toolCallId)
completedToolExecutionIds.add(event.toolCallId)
} else {
startedToolExecutionIds.add(event.toolCallId)
}
startedToolExecutionIds.add(event.toolCallId)
}
break
case 'tool_result':
logger.info(`[${tracker.requestId}] Tool result received:`, {
toolCallId: event.toolCallId,
toolName: event.toolName,
success: event.success,
result: `${JSON.stringify(event.result).substring(0, 200)}...`,
resultSize: JSON.stringify(event.result).length,
})
if (event.toolCallId) {
completedToolExecutionIds.add(event.toolCallId)
}
@@ -669,9 +559,6 @@ export async function POST(req: NextRequest) {
case 'start':
if (event.data?.responseId) {
responseIdFromStart = event.data.responseId
logger.info(
`[${tracker.requestId}] Received start event with responseId: ${responseIdFromStart}`
)
}
break
@@ -679,9 +566,7 @@ export async function POST(req: NextRequest) {
if (event.data?.responseId) {
responseIdFromDone = event.data.responseId
lastDoneResponseId = responseIdFromDone
logger.info(
`[${tracker.requestId}] Received done event with responseId: ${responseIdFromDone}`
)
// Mark this done as safe only if no tool call is currently in progress or pending
const announced = announcedToolCallIds.size
const completed = completedToolExecutionIds.size
@@ -689,34 +574,14 @@ export async function POST(req: NextRequest) {
const hasToolInProgress = announced > completed || started > completed
if (!hasToolInProgress) {
lastSafeDoneResponseId = responseIdFromDone
logger.info(
`[${tracker.requestId}] Marked done as SAFE (no tools in progress)`
)
} else {
logger.info(
`[${tracker.requestId}] Done received but tools are in progress (announced=${announced}, started=${started}, completed=${completed})`
)
}
}
if (isFirstDone) {
logger.info(
`[${tracker.requestId}] Initial AI response complete, tool count: ${toolCalls.length}`
)
isFirstDone = false
} else {
logger.info(`[${tracker.requestId}] Conversation round complete`)
}
break
case 'error':
logger.error(`[${tracker.requestId}] Stream error event:`, event.error)
break
default:
logger.debug(
`[${tracker.requestId}] Unknown event type: ${event.type}`,
event
)
}
} catch (e) {
// Enhanced error handling for large payloads and parsing issues
@@ -874,6 +739,11 @@ export async function POST(req: NextRequest) {
content: message,
timestamp: new Date().toISOString(),
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
...(Array.isArray(contexts) &&
contexts.length > 0 && {
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
}),
}
const assistantMessage = {
@@ -888,9 +758,22 @@ export async function POST(req: NextRequest) {
// Start title generation in parallel if this is first message (non-streaming)
if (actualChatId && !currentChat.title && conversationHistory.length === 0) {
logger.info(`[${tracker.requestId}] Starting title generation for non-streaming response`)
generateChatTitleAsync(actualChatId, message, tracker.requestId).catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
})
generateChatTitle(message)
.then(async (title) => {
if (title) {
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
})
.where(eq(copilotChats.id, actualChatId!))
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
}
})
.catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
})
}
// Update chat in database immediately (without blocking for title)
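
Pulling the new `contexts` flow together: the client may attach typed context references, the server preprocesses them into `{ type, content }` pairs, forwards them to the Sim Agent as `context`, and echoes them into the saved user message as a `contexts` content block. A sketch of a request body (IDs hypothetical, `kind` values from the schema):

```ts
const body = {
  message: 'Why did the last run fail?',
  workflowId: 'wf_123',
  contexts: [
    { kind: 'past_chat', label: 'Earlier debugging session', chatId: 'chat_456' },
    { kind: 'logs', label: 'Latest execution logs', workflowId: 'wf_123' },
  ],
}
// Server side: processContextsServer(contexts, userId) → [{ type, content }, ...]
```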

View File

@@ -229,7 +229,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists - override the default empty array
const existingChat = {
id: 'chat-123',
userId: 'user-123',
@@ -267,7 +266,6 @@ describe('Copilot Chat Update Messages API Route', () => {
messageCount: 2,
})
// Verify database operations
expect(mockSelect).toHaveBeenCalled()
expect(mockUpdate).toHaveBeenCalled()
expect(mockSet).toHaveBeenCalledWith({
@@ -280,7 +278,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-456',
userId: 'user-123',
@@ -341,7 +338,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-789',
userId: 'user-123',
@@ -374,7 +370,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock database error during chat lookup
mockLimit.mockRejectedValueOnce(new Error('Database connection failed'))
const req = createMockRequest('POST', {
@@ -401,7 +396,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-123',
userId: 'user-123',
@@ -409,7 +403,6 @@ describe('Copilot Chat Update Messages API Route', () => {
}
mockLimit.mockResolvedValueOnce([existingChat])
// Mock database error during update
mockSet.mockReturnValueOnce({
where: vi.fn().mockRejectedValue(new Error('Update operation failed')),
})
@@ -438,7 +431,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Create a request with invalid JSON
const req = new NextRequest('http://localhost:3000/api/copilot/chat/update-messages', {
method: 'POST',
body: '{invalid-json',
@@ -459,7 +451,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-large',
userId: 'user-123',
@@ -467,7 +458,6 @@ describe('Copilot Chat Update Messages API Route', () => {
}
mockLimit.mockResolvedValueOnce([existingChat])
// Create a large array of messages
const messages = Array.from({ length: 100 }, (_, i) => ({
id: `msg-${i + 1}`,
role: i % 2 === 0 ? 'user' : 'assistant',
@@ -500,7 +490,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-mixed',
userId: 'user-123',

View File

@@ -28,7 +28,7 @@ const UpdateMessagesSchema = z.object({
.array(
z.object({
id: z.string(),
s3_key: z.string(),
key: z.string(),
filename: z.string(),
media_type: z.string(),
size: z.number(),

View File

@@ -0,0 +1,39 @@
import { desc, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import {
authenticateCopilotRequestSessionOnly,
createInternalServerErrorResponse,
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotChats } from '@/db/schema'
const logger = createLogger('CopilotChatsListAPI')
export async function GET(_req: NextRequest) {
try {
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
if (!isAuthenticated || !userId) {
return createUnauthorizedResponse()
}
const chats = await db
.select({
id: copilotChats.id,
title: copilotChats.title,
workflowId: copilotChats.workflowId,
updatedAt: copilotChats.updatedAt,
})
.from(copilotChats)
.where(eq(copilotChats.userId, userId))
.orderBy(desc(copilotChats.updatedAt))
logger.info(`Retrieved ${chats.length} chats for user ${userId}`)
return NextResponse.json({ success: true, chats })
} catch (error) {
logger.error('Error fetching user copilot chats:', error)
return createInternalServerErrorResponse('Failed to fetch user chats')
}
}
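
A hedged usage sketch for this list endpoint; the /api/copilot/chats path is an assumption (the compare view omits file paths), and authentication rides on the session cookie checked by authenticateCopilotRequestSessionOnly.

// Client-side sketch, assuming the route is mounted at /api/copilot/chats.
interface ChatSummary {
  id: string
  title: string | null
  workflowId: string
  updatedAt: string
}

async function listCopilotChats(): Promise<ChatSummary[]> {
  const res = await fetch('/api/copilot/chats', { credentials: 'include' })
  if (!res.ok) throw new Error(`Failed to list chats: ${res.status}`)
  const body: { success: boolean; chats: ChatSummary[] } = await res.json()
  return body.chats
}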

View File

@@ -0,0 +1,53 @@
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import {
authenticateCopilotRequestSessionOnly,
createBadRequestResponse,
createInternalServerErrorResponse,
createRequestTracker,
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { routeExecution } from '@/lib/copilot/tools/server/router'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('ExecuteCopilotServerToolAPI')
const ExecuteSchema = z.object({
toolName: z.string(),
payload: z.unknown().optional(),
})
export async function POST(req: NextRequest) {
const tracker = createRequestTracker()
try {
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
if (!isAuthenticated || !userId) {
return createUnauthorizedResponse()
}
const body = await req.json()
try {
const preview = JSON.stringify(body).slice(0, 300)
logger.debug(`[${tracker.requestId}] Incoming request body preview`, { preview })
} catch {}
const { toolName, payload } = ExecuteSchema.parse(body)
logger.info(`[${tracker.requestId}] Executing server tool`, { toolName })
const result = await routeExecution(toolName, payload)
try {
const resultPreview = JSON.stringify(result).slice(0, 300)
logger.debug(`[${tracker.requestId}] Server tool result preview`, { toolName, resultPreview })
} catch {}
return NextResponse.json({ success: true, result })
} catch (error) {
if (error instanceof z.ZodError) {
logger.debug(`[${tracker.requestId}] Zod validation error`, { issues: error.issues })
return createBadRequestResponse('Invalid request body for execute-copilot-server-tool')
}
logger.error(`[${tracker.requestId}] Failed to execute server tool:`, error)
return createInternalServerErrorResponse('Failed to execute server tool')
}
}
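
A hedged client sketch for the execute endpoint; the route path is inferred from the error string above and is an assumption, while the request body shape follows ExecuteSchema exactly.

// Sketch: call the server-tool execution route with a typed wrapper.
async function executeServerTool<T = unknown>(toolName: string, payload?: unknown): Promise<T> {
  const res = await fetch('/api/copilot/execute-copilot-server-tool', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    credentials: 'include',
    body: JSON.stringify({ toolName, payload }),
  })
  if (!res.ok) throw new Error(`Tool execution failed: ${res.status}`)
  const body: { success: boolean; result: T } = await res.json()
  return body.result
}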

View File

@@ -1,761 +1,7 @@
/**
* Tests for copilot methods API route
*
* @vitest-environment node
*/
import { NextRequest } from 'next/server'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockRequest,
mockCryptoUuid,
setupCommonApiMocks,
} from '@/app/api/__test-utils__/utils'
import { describe, expect, it } from 'vitest'
describe('Copilot Methods API Route', () => {
const mockRedisGet = vi.fn()
const mockRedisSet = vi.fn()
const mockGetRedisClient = vi.fn()
const mockToolRegistryHas = vi.fn()
const mockToolRegistryGet = vi.fn()
const mockToolRegistryExecute = vi.fn()
const mockToolRegistryGetAvailableIds = vi.fn()
beforeEach(() => {
vi.resetModules()
setupCommonApiMocks()
mockCryptoUuid()
// Mock Redis client
const mockRedisClient = {
get: mockRedisGet,
set: mockRedisSet,
}
mockGetRedisClient.mockReturnValue(mockRedisClient)
mockRedisGet.mockResolvedValue(null)
mockRedisSet.mockResolvedValue('OK')
vi.doMock('@/lib/redis', () => ({
getRedisClient: mockGetRedisClient,
}))
// Mock tool registry
const mockToolRegistry = {
has: mockToolRegistryHas,
get: mockToolRegistryGet,
execute: mockToolRegistryExecute,
getAvailableIds: mockToolRegistryGetAvailableIds,
}
mockToolRegistryHas.mockReturnValue(true)
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: false })
mockToolRegistryExecute.mockResolvedValue({ success: true, data: 'Tool executed successfully' })
mockToolRegistryGetAvailableIds.mockReturnValue(['test-tool', 'another-tool'])
vi.doMock('@/lib/copilot/tools/server-tools/registry', () => ({
copilotToolRegistry: mockToolRegistry,
}))
// Mock environment variables
vi.doMock('@/lib/env', () => ({
env: {
INTERNAL_API_SECRET: 'test-secret-key',
COPILOT_API_KEY: 'test-copilot-key',
},
}))
// Mock setTimeout for polling
vi.spyOn(global, 'setTimeout').mockImplementation((callback, _delay) => {
if (typeof callback === 'function') {
setImmediate(callback)
}
return setTimeout(() => {}, 0) as any
})
// Mock Date.now for timeout control
let mockTime = 1640995200000
vi.spyOn(Date, 'now').mockImplementation(() => {
mockTime += 1000 // Add 1 second each call
return mockTime
})
// Mock crypto.randomUUID for request IDs
vi.spyOn(crypto, 'randomUUID').mockReturnValue('test-request-id')
})
afterEach(() => {
vi.clearAllMocks()
vi.restoreAllMocks()
})
describe('POST', () => {
it('should return 401 when API key is missing', async () => {
const req = createMockRequest('POST', {
methodId: 'test-tool',
params: {},
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(401)
const responseData = await response.json()
expect(responseData).toEqual({
success: false,
error: 'API key required',
})
})
it('should return 401 when API key is invalid', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'invalid-key',
},
body: JSON.stringify({
methodId: 'test-tool',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(401)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(typeof responseData.error).toBe('string')
})
it('should return 401 when internal API key is not configured', async () => {
// Mock environment with no API key
vi.doMock('@/lib/env', () => ({
env: {
INTERNAL_API_SECRET: undefined,
COPILOT_API_KEY: 'test-copilot-key',
},
}))
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'any-key',
},
body: JSON.stringify({
methodId: 'test-tool',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(401)
const responseData = await response.json()
expect(responseData.status).toBeUndefined()
expect(responseData.success).toBe(false)
expect(typeof responseData.error).toBe('string')
})
it('should return 400 for invalid request body - missing methodId', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
params: {},
// Missing methodId
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(400)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toContain('Required')
})
it('should return 400 for empty methodId', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: '',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(400)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toContain('Method ID is required')
})
it('should return 400 when tool is not found in registry', async () => {
mockToolRegistryHas.mockReturnValue(false)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'unknown-tool',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(400)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toContain('Unknown method: unknown-tool')
expect(responseData.error).toContain('Available methods: test-tool, another-tool')
})
it('should successfully execute a tool without interruption', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'test-tool',
params: { key: 'value' },
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', { key: 'value' })
})
it('should handle tool execution with default empty params', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'test-tool',
// No params provided
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', {})
})
it('should return 400 when tool requires interrupt but no toolCallId provided', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
// No toolCallId provided
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(400)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe(
'This tool requires approval but no tool call ID was provided'
)
})
it('should handle tool execution with interrupt - user approval', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return accepted status immediately (simulate quick approval)
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'accepted', message: 'User approved' })
)
// Reset Date.now mock to not trigger timeout
let mockTime = 1640995200000
vi.spyOn(Date, 'now').mockImplementation(() => {
mockTime += 100 // Small increment to avoid timeout
return mockTime
})
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: { key: 'value' },
toolCallId: 'tool-call-123',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
// Verify Redis operations
expect(mockRedisSet).toHaveBeenCalledWith(
'tool_call:tool-call-123',
expect.stringContaining('"status":"pending"'),
'EX',
86400
)
expect(mockRedisGet).toHaveBeenCalledWith('tool_call:tool-call-123')
expect(mockToolRegistryExecute).toHaveBeenCalledWith('interrupt-tool', {
key: 'value',
confirmationMessage: 'User approved',
fullData: {
message: 'User approved',
status: 'accepted',
},
})
})
it('should handle tool execution with interrupt - user rejection', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return rejected status
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'rejected', message: 'User rejected' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-456',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200) // User rejection returns 200
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe(
'The user decided to skip running this tool. This was a user decision.'
)
// Tool should not be executed when rejected
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
})
it('should handle tool execution with interrupt - error status', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return error status
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'error', message: 'Tool execution failed' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-error',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(500)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Tool execution failed')
})
it('should handle tool execution with interrupt - background status', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return background status
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'background', message: 'Running in background' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-bg',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
expect(mockToolRegistryExecute).toHaveBeenCalled()
})
it('should handle tool execution with interrupt - success status', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return success status
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'success', message: 'Completed successfully' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-success',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
expect(mockToolRegistryExecute).toHaveBeenCalled()
})
it('should handle tool execution with interrupt - timeout', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to never return a status (timeout scenario)
mockRedisGet.mockResolvedValue(null)
// Mock Date.now to trigger timeout quickly
let mockTime = 1640995200000
vi.spyOn(Date, 'now').mockImplementation(() => {
mockTime += 100000 // Add 100 seconds each call to trigger timeout
return mockTime
})
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-timeout',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(408) // Request Timeout
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Tool execution request timed out')
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
})
it('should handle unexpected status in interrupt flow', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return unexpected status
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'unknown-status', message: 'Unknown' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-unknown',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(500)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Unexpected tool call status: unknown-status')
})
it('should handle Redis client unavailable for interrupt flow', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
mockGetRedisClient.mockReturnValue(null)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-no-redis',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(408) // Timeout due to Redis unavailable
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Tool execution request timed out')
})
it('should handle no_op tool with confirmation message', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return accepted status with message
mockRedisGet.mockResolvedValue(
JSON.stringify({ status: 'accepted', message: 'Confirmation message' })
)
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'no_op',
params: { existing: 'param' },
toolCallId: 'tool-call-noop',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
// Verify confirmation message was added to params
expect(mockToolRegistryExecute).toHaveBeenCalledWith('no_op', {
existing: 'param',
confirmationMessage: 'Confirmation message',
fullData: {
message: 'Confirmation message',
status: 'accepted',
},
})
})
it('should handle Redis errors in interrupt flow', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to throw an error
mockRedisGet.mockRejectedValue(new Error('Redis connection failed'))
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-redis-error',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(408) // Timeout due to Redis error
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Tool execution request timed out')
})
it('should handle tool execution failure', async () => {
mockToolRegistryExecute.mockResolvedValue({
success: false,
error: 'Tool execution failed',
})
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'failing-tool',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200) // Still returns 200, but with success: false
const responseData = await response.json()
expect(responseData).toEqual({
success: false,
error: 'Tool execution failed',
})
})
it('should handle JSON parsing errors in request body', async () => {
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: '{invalid-json',
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(500)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toContain('JSON')
})
it('should handle tool registry execution throwing an error', async () => {
mockToolRegistryExecute.mockRejectedValue(new Error('Registry execution failed'))
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'error-tool',
params: {},
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(500)
const responseData = await response.json()
expect(responseData.success).toBe(false)
expect(responseData.error).toBe('Registry execution failed')
})
it('should handle old format Redis status (string instead of JSON)', async () => {
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
// Mock Redis to return old format (direct status string)
mockRedisGet.mockResolvedValue('accepted')
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'x-api-key': 'test-secret-key',
},
body: JSON.stringify({
methodId: 'interrupt-tool',
params: {},
toolCallId: 'tool-call-old-format',
}),
})
const { POST } = await import('@/app/api/copilot/methods/route')
const response = await POST(req)
expect(response.status).toBe(200)
const responseData = await response.json()
expect(responseData).toEqual({
success: true,
data: 'Tool executed successfully',
})
expect(mockToolRegistryExecute).toHaveBeenCalled()
})
describe('copilot methods route placeholder', () => {
it('loads test suite', () => {
expect(true).toBe(true)
})
})

View File

@@ -1,395 +0,0 @@
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { copilotToolRegistry } from '@/lib/copilot/tools/server-tools/registry'
import type { NotificationStatus } from '@/lib/copilot/types'
import { checkCopilotApiKey, checkInternalApiKey } from '@/lib/copilot/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { getRedisClient } from '@/lib/redis'
import { createErrorResponse } from '@/app/api/copilot/methods/utils'
const logger = createLogger('CopilotMethodsAPI')
/**
* Add a tool call to Redis with 'pending' status
*/
async function addToolToRedis(toolCallId: string): Promise<void> {
if (!toolCallId) {
logger.warn('addToolToRedis: No tool call ID provided')
return
}
const redis = getRedisClient()
if (!redis) {
logger.warn('addToolToRedis: Redis client not available')
return
}
try {
const key = `tool_call:${toolCallId}`
const status: NotificationStatus = 'pending'
// Store as JSON object for consistency with confirm API
const toolCallData = {
status,
message: null,
timestamp: new Date().toISOString(),
}
// Set with 24 hour expiry (86400 seconds)
await redis.set(key, JSON.stringify(toolCallData), 'EX', 86400)
logger.info('Tool call added to Redis', {
toolCallId,
key,
status,
})
} catch (error) {
logger.error('Failed to add tool call to Redis', {
toolCallId,
error: error instanceof Error ? error.message : 'Unknown error',
})
}
}
/**
* Poll Redis for tool call status updates
 * Returns when the status changes to 'accepted' or 'rejected', or times out after 10 minutes
*/
async function pollRedisForTool(
toolCallId: string
): Promise<{ status: NotificationStatus; message?: string; fullData?: any } | null> {
const redis = getRedisClient()
if (!redis) {
logger.warn('pollRedisForTool: Redis client not available')
return null
}
const key = `tool_call:${toolCallId}`
const timeout = 600000 // 10 minutes for long-running operations
const pollInterval = 1000 // 1 second
const startTime = Date.now()
while (Date.now() - startTime < timeout) {
try {
const redisValue = await redis.get(key)
if (!redisValue) {
// Wait before next poll
await new Promise((resolve) => setTimeout(resolve, pollInterval))
continue
}
let status: NotificationStatus | null = null
let message: string | undefined
let fullData: any = null
// Try to parse as JSON (new format), fallback to string (old format)
try {
const parsedData = JSON.parse(redisValue)
status = parsedData.status as NotificationStatus
message = parsedData.message || undefined
fullData = parsedData // Store the full parsed data
} catch {
// Fallback to old format (direct status string)
status = redisValue as NotificationStatus
}
if (status !== 'pending') {
// Log the message found in redis prominently - always log, even if message is null/undefined
logger.info('Redis poller found non-pending status', {
toolCallId,
foundMessage: message,
messageType: typeof message,
messageIsNull: message === null,
messageIsUndefined: message === undefined,
status,
duration: Date.now() - startTime,
rawRedisValue: redisValue,
})
// Special logging for set environment variables tool when Redis status is found
if (toolCallId && (status === 'accepted' || status === 'rejected')) {
logger.info('SET_ENV_VARS: Redis polling found status update', {
toolCallId,
foundStatus: status,
redisMessage: message,
pollDuration: Date.now() - startTime,
redisKey: `tool_call:${toolCallId}`,
})
}
return { status, message, fullData }
}
// Wait before next poll
await new Promise((resolve) => setTimeout(resolve, pollInterval))
} catch (error) {
logger.error('Error polling Redis for tool call status', {
toolCallId,
error: error instanceof Error ? error.message : 'Unknown error',
})
return null
}
}
logger.warn('Tool call polling timed out', {
toolCallId,
timeout,
})
return null
}
/**
* Handle tool calls that require user interruption/approval
* Returns { approved: boolean, rejected: boolean, error?: boolean, message?: string } to distinguish between rejection, timeout, and error
*/
async function interruptHandler(toolCallId: string): Promise<{
approved: boolean
rejected: boolean
error?: boolean
message?: string
fullData?: any
}> {
if (!toolCallId) {
logger.error('interruptHandler: No tool call ID provided')
return { approved: false, rejected: false, error: true, message: 'No tool call ID provided' }
}
logger.info('Starting interrupt handler for tool call', { toolCallId })
try {
// Step 1: Add tool to Redis with 'pending' status
await addToolToRedis(toolCallId)
// Step 2: Poll Redis for status update
const result = await pollRedisForTool(toolCallId)
if (!result) {
logger.error('Failed to get tool call status or timed out', { toolCallId })
return { approved: false, rejected: false }
}
const { status, message, fullData } = result
if (status === 'rejected') {
logger.info('Tool execution rejected by user', { toolCallId, message })
return { approved: false, rejected: true, message, fullData }
}
if (status === 'accepted') {
logger.info('Tool execution approved by user', { toolCallId, message })
return { approved: true, rejected: false, message, fullData }
}
if (status === 'error') {
logger.error('Tool execution failed with error', { toolCallId, message })
return { approved: false, rejected: false, error: true, message, fullData }
}
if (status === 'background') {
logger.info('Tool execution moved to background', { toolCallId, message })
return { approved: true, rejected: false, message, fullData }
}
if (status === 'success') {
logger.info('Tool execution completed successfully', { toolCallId, message })
return { approved: true, rejected: false, message, fullData }
}
logger.warn('Unexpected tool call status', { toolCallId, status, message })
return {
approved: false,
rejected: false,
error: true,
message: `Unexpected tool call status: ${status}`,
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
logger.error('Error in interrupt handler', {
toolCallId,
error: errorMessage,
})
return {
approved: false,
rejected: false,
error: true,
message: `Interrupt handler error: ${errorMessage}`,
}
}
}
const MethodExecutionSchema = z.object({
methodId: z.string().min(1, 'Method ID is required'),
params: z.record(z.any()).optional().default({}),
toolCallId: z.string().nullable().optional().default(null),
})
/**
* POST /api/copilot/methods
* Execute a method based on methodId with internal API key auth
*/
export async function POST(req: NextRequest) {
const requestId = crypto.randomUUID()
const startTime = Date.now()
try {
// Evaluate both auth schemes; pass if either is valid
const internalAuth = checkInternalApiKey(req)
const copilotAuth = checkCopilotApiKey(req)
const isAuthenticated = !!(internalAuth?.success || copilotAuth?.success)
if (!isAuthenticated) {
const errorMessage = copilotAuth.error || internalAuth.error || 'Authentication failed'
return NextResponse.json(createErrorResponse(errorMessage), {
status: 401,
})
}
const body = await req.json()
const { methodId, params, toolCallId } = MethodExecutionSchema.parse(body)
logger.info(`[${requestId}] Method execution request`, {
methodId,
toolCallId,
hasParams: !!params && Object.keys(params).length > 0,
})
// Check if tool exists in registry
if (!copilotToolRegistry.has(methodId)) {
logger.error(`[${requestId}] Tool not found in registry: ${methodId}`, {
methodId,
toolCallId,
availableTools: copilotToolRegistry.getAvailableIds(),
registrySize: copilotToolRegistry.getAvailableIds().length,
})
return NextResponse.json(
createErrorResponse(
`Unknown method: ${methodId}. Available methods: ${copilotToolRegistry.getAvailableIds().join(', ')}`
),
{ status: 400 }
)
}
logger.info(`[${requestId}] Tool found in registry: ${methodId}`, {
toolCallId,
})
// Check if the tool requires interrupt/approval
const tool = copilotToolRegistry.get(methodId)
if (tool?.requiresInterrupt) {
if (!toolCallId) {
logger.warn(`[${requestId}] Tool requires interrupt but no toolCallId provided`, {
methodId,
})
return NextResponse.json(
createErrorResponse('This tool requires approval but no tool call ID was provided'),
{ status: 400 }
)
}
logger.info(`[${requestId}] Tool requires interrupt, starting approval process`, {
methodId,
toolCallId,
})
// Handle interrupt flow
const { approved, rejected, error, message, fullData } = await interruptHandler(toolCallId)
if (rejected) {
logger.info(`[${requestId}] Tool execution rejected by user`, {
methodId,
toolCallId,
message,
})
return NextResponse.json(
createErrorResponse(
'The user decided to skip running this tool. This was a user decision.'
),
{ status: 200 } // Changed to 200 - user rejection is a valid response
)
}
if (error) {
logger.error(`[${requestId}] Tool execution failed with error`, {
methodId,
toolCallId,
message,
})
return NextResponse.json(
createErrorResponse(message || 'Tool execution failed with unknown error'),
{ status: 500 } // 500 Internal Server Error
)
}
if (!approved) {
logger.warn(`[${requestId}] Tool execution timed out`, {
methodId,
toolCallId,
})
return NextResponse.json(
createErrorResponse('Tool execution request timed out'),
{ status: 408 } // 408 Request Timeout
)
}
logger.info(`[${requestId}] Tool execution approved by user`, {
methodId,
toolCallId,
message,
})
// For tools that need confirmation data, pass the message and/or fullData as parameters
if (message) {
params.confirmationMessage = message
}
if (fullData) {
params.fullData = fullData
}
}
// Execute the tool directly via registry
const result = await copilotToolRegistry.execute(methodId, params)
logger.info(`[${requestId}] Tool execution result:`, {
methodId,
toolCallId,
success: result.success,
hasData: !!result.data,
hasError: !!result.error,
})
const duration = Date.now() - startTime
logger.info(`[${requestId}] Method execution completed: ${methodId}`, {
methodId,
toolCallId,
duration,
success: result.success,
})
return NextResponse.json(result)
} catch (error) {
const duration = Date.now() - startTime
if (error instanceof z.ZodError) {
logger.error(`[${requestId}] Request validation error:`, {
duration,
errors: error.errors,
})
return NextResponse.json(
createErrorResponse(
`Invalid request data: ${error.errors.map((e) => e.message).join(', ')}`
),
{ status: 400 }
)
}
logger.error(`[${requestId}] Unexpected error:`, {
duration,
error: error instanceof Error ? error.message : 'Unknown error',
stack: error instanceof Error ? error.stack : undefined,
})
return NextResponse.json(
createErrorResponse(error instanceof Error ? error.message : 'Internal server error'),
{ status: 500 }
)
}
}

View File

@@ -1,14 +0,0 @@
import type { CopilotToolResponse } from '@/lib/copilot/tools/server-tools/base'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CopilotMethodsUtils')
/**
* Create a standardized error response
*/
export function createErrorResponse(error: string): CopilotToolResponse {
return {
success: false,
error,
}
}

View File

@@ -0,0 +1,125 @@
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import {
authenticateCopilotRequestSessionOnly,
createBadRequestResponse,
createInternalServerErrorResponse,
createRequestTracker,
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const logger = createLogger('CopilotMarkToolCompleteAPI')
// Sim Agent API configuration
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
// Schema for mark-complete request
const MarkCompleteSchema = z.object({
id: z.string(),
name: z.string(),
status: z.number().int(),
message: z.any().optional(),
data: z.any().optional(),
})
/**
* POST /api/copilot/tools/mark-complete
* Proxy to Sim Agent: POST /api/tools/mark-complete
*/
export async function POST(req: NextRequest) {
const tracker = createRequestTracker()
try {
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
if (!isAuthenticated || !userId) {
return createUnauthorizedResponse()
}
const body = await req.json()
// Log raw body shape for diagnostics (avoid dumping huge payloads)
try {
const bodyPreview = JSON.stringify(body).slice(0, 300)
logger.debug(`[${tracker.requestId}] Incoming mark-complete raw body preview`, {
preview: `${bodyPreview}${bodyPreview.length === 300 ? '...' : ''}`,
})
} catch {}
const parsed = MarkCompleteSchema.parse(body)
const messagePreview = (() => {
try {
const s =
typeof parsed.message === 'string' ? parsed.message : JSON.stringify(parsed.message)
return s ? `${s.slice(0, 200)}${s.length > 200 ? '...' : ''}` : undefined
} catch {
return undefined
}
})()
logger.info(`[${tracker.requestId}] Forwarding tool mark-complete`, {
userId,
toolCallId: parsed.id,
toolName: parsed.name,
status: parsed.status,
hasMessage: parsed.message !== undefined,
hasData: parsed.data !== undefined,
messagePreview,
agentUrl: `${SIM_AGENT_API_URL}/api/tools/mark-complete`,
})
const agentRes = await fetch(`${SIM_AGENT_API_URL}/api/tools/mark-complete`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify(parsed),
})
// Attempt to parse agent response JSON
let agentJson: any = null
let agentText: string | null = null
try {
agentJson = await agentRes.json()
} catch (_) {
try {
agentText = await agentRes.text()
} catch {}
}
logger.info(`[${tracker.requestId}] Agent responded to mark-complete`, {
status: agentRes.status,
ok: agentRes.ok,
responseJsonPreview: agentJson ? JSON.stringify(agentJson).slice(0, 300) : undefined,
responseTextPreview: agentText ? agentText.slice(0, 300) : undefined,
})
if (agentRes.ok) {
return NextResponse.json({ success: true })
}
const errorMessage =
agentJson?.error || agentText || `Agent responded with status ${agentRes.status}`
const status = agentRes.status >= 500 ? 500 : 400
logger.warn(`[${tracker.requestId}] Mark-complete failed`, {
status,
error: errorMessage,
})
return NextResponse.json({ success: false, error: errorMessage }, { status })
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${tracker.requestId}] Invalid mark-complete request body`, {
issues: error.issues,
})
return createBadRequestResponse('Invalid request body for mark-complete')
}
logger.error(`[${tracker.requestId}] Failed to proxy mark-complete:`, error)
return createInternalServerErrorResponse('Failed to mark tool as complete')
}
}
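
A usage sketch for the proxy above; the path comes from the route's doc comment, while the numeric status values the Sim Agent expects are not shown in this diff and are left to the caller.

// Sketch: report a finished tool call through the mark-complete proxy.
async function markToolComplete(
  id: string,
  name: string,
  status: number,
  message?: unknown,
  data?: unknown
): Promise<boolean> {
  const res = await fetch('/api/copilot/tools/mark-complete', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    credentials: 'include',
    body: JSON.stringify({ id, name, status, message, data }),
  })
  const body: { success: boolean; error?: string } = await res.json()
  return body.success
}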

View File

@@ -109,7 +109,9 @@ export async function PUT(request: NextRequest) {
// If we can't decrypt the existing value, treat as changed and re-encrypt
logger.warn(
`[${requestId}] Could not decrypt existing variable ${key}, re-encrypting`,
{ error: decryptError }
{
error: decryptError,
}
)
variablesToEncrypt[key] = newValue
updatedVariables.push(key)

View File

@@ -1,16 +1,8 @@
import {
AbortMultipartUploadCommand,
CompleteMultipartUploadCommand,
CreateMultipartUploadCommand,
UploadPartCommand,
} from '@aws-sdk/client-s3'
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import { S3_KB_CONFIG } from '@/lib/uploads/setup'
import { BLOB_KB_CONFIG } from '@/lib/uploads/setup'
const logger = createLogger('MultipartUploadAPI')
@@ -26,15 +18,6 @@ interface GetPartUrlsRequest {
partNumbers: number[]
}
interface CompleteMultipartRequest {
uploadId: string
key: string
parts: Array<{
ETag: string
PartNumber: number
}>
}
export async function POST(request: NextRequest) {
try {
const session = await getSession()
@@ -44,106 +27,214 @@ export async function POST(request: NextRequest) {
const action = request.nextUrl.searchParams.get('action')
if (!isUsingCloudStorage() || getStorageProvider() !== 's3') {
if (!isUsingCloudStorage()) {
return NextResponse.json(
{ error: 'Multipart upload is only available with S3 storage' },
{ error: 'Multipart upload is only available with cloud storage (S3 or Azure Blob)' },
{ status: 400 }
)
}
const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
const s3Client = getS3Client()
const storageProvider = getStorageProvider()
switch (action) {
case 'initiate': {
const data: InitiateMultipartRequest = await request.json()
const { fileName, contentType } = data
const { fileName, contentType, fileSize } = data
const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
if (storageProvider === 's3') {
const { initiateS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
const command = new CreateMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: uniqueKey,
ContentType: contentType,
Metadata: {
originalName: fileName,
uploadedAt: new Date().toISOString(),
purpose: 'knowledge-base',
},
})
const result = await initiateS3MultipartUpload({
fileName,
contentType,
fileSize,
})
const response = await s3Client.send(command)
logger.info(`Initiated S3 multipart upload for ${fileName}: ${result.uploadId}`)
logger.info(`Initiated multipart upload for ${fileName}: ${response.UploadId}`)
return NextResponse.json({
uploadId: result.uploadId,
key: result.key,
})
}
if (storageProvider === 'blob') {
const { initiateMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
return NextResponse.json({
uploadId: response.UploadId,
key: uniqueKey,
})
const result = await initiateMultipartUpload({
fileName,
contentType,
fileSize,
customConfig: {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
},
})
logger.info(`Initiated Azure multipart upload for ${fileName}: ${result.uploadId}`)
return NextResponse.json({
uploadId: result.uploadId,
key: result.key,
})
}
return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}
case 'get-part-urls': {
const data: GetPartUrlsRequest = await request.json()
const { uploadId, key, partNumbers } = data
const presignedUrls = await Promise.all(
partNumbers.map(async (partNumber) => {
const command = new UploadPartCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
PartNumber: partNumber,
UploadId: uploadId,
})
if (storageProvider === 's3') {
const { getS3MultipartPartUrls } = await import('@/lib/uploads/s3/s3-client')
const url = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
return { partNumber, url }
const presignedUrls = await getS3MultipartPartUrls(key, uploadId, partNumbers)
return NextResponse.json({ presignedUrls })
}
if (storageProvider === 'blob') {
const { getMultipartPartUrls } = await import('@/lib/uploads/blob/blob-client')
const presignedUrls = await getMultipartPartUrls(key, uploadId, partNumbers, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})
)
return NextResponse.json({ presignedUrls })
return NextResponse.json({ presignedUrls })
}
return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}
case 'complete': {
const data: CompleteMultipartRequest = await request.json()
const data = await request.json()
// Handle batch completion
if ('uploads' in data) {
const results = await Promise.all(
data.uploads.map(async (upload: any) => {
const { uploadId, key } = upload
if (storageProvider === 's3') {
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
const parts = upload.parts // S3 format: { ETag, PartNumber }
const result = await completeS3MultipartUpload(key, uploadId, parts)
return {
success: true,
location: result.location,
path: result.path,
key: result.key,
}
}
if (storageProvider === 'blob') {
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
const parts = upload.parts // Azure format: { blockId, partNumber }
const result = await completeMultipartUpload(key, uploadId, parts, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})
return {
success: true,
location: result.location,
path: result.path,
key: result.key,
}
}
throw new Error(`Unsupported storage provider: ${storageProvider}`)
})
)
logger.info(`Completed ${data.uploads.length} multipart uploads`)
return NextResponse.json({ results })
}
// Handle single completion
const { uploadId, key, parts } = data
const command = new CompleteMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
UploadId: uploadId,
MultipartUpload: {
Parts: parts.sort((a, b) => a.PartNumber - b.PartNumber),
},
})
if (storageProvider === 's3') {
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
const response = await s3Client.send(command)
const result = await completeS3MultipartUpload(key, uploadId, parts)
logger.info(`Completed multipart upload for key ${key}`)
logger.info(`Completed S3 multipart upload for key ${key}`)
const finalPath = `/api/files/serve/s3/${encodeURIComponent(key)}`
return NextResponse.json({
success: true,
location: result.location,
path: result.path,
key: result.key,
})
}
if (storageProvider === 'blob') {
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
return NextResponse.json({
success: true,
location: response.Location,
path: finalPath,
key,
})
const result = await completeMultipartUpload(key, uploadId, parts, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})
logger.info(`Completed Azure multipart upload for key ${key}`)
return NextResponse.json({
success: true,
location: result.location,
path: result.path,
key: result.key,
})
}
return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}
case 'abort': {
const data = await request.json()
const { uploadId, key } = data
const command = new AbortMultipartUploadCommand({
Bucket: S3_KB_CONFIG.bucket,
Key: key,
UploadId: uploadId,
})
if (storageProvider === 's3') {
const { abortS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
await s3Client.send(command)
await abortS3MultipartUpload(key, uploadId)
logger.info(`Aborted multipart upload for key ${key}`)
logger.info(`Aborted S3 multipart upload for key ${key}`)
} else if (storageProvider === 'blob') {
const { abortMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
await abortMultipartUpload(key, uploadId, {
containerName: BLOB_KB_CONFIG.containerName,
accountName: BLOB_KB_CONFIG.accountName,
accountKey: BLOB_KB_CONFIG.accountKey,
connectionString: BLOB_KB_CONFIG.connectionString,
})
logger.info(`Aborted Azure multipart upload for key ${key}`)
} else {
return NextResponse.json(
{ error: `Unsupported storage provider: ${storageProvider}` },
{ status: 400 }
)
}
return NextResponse.json({ success: true })
}
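
A hedged sketch of the client-side sequence the ?action= switch above implies, using the S3 part format ({ ETag, PartNumber }); Azure callers would send { blockId, partNumber } instead. The route path and the 8 MB part size are assumptions.

// Sketch: initiate -> get-part-urls -> PUT parts -> complete, aborting on failure.
const MULTIPART_URL = '/api/files/multipart' // assumed mount point

async function post<T>(action: string, payload: unknown): Promise<T> {
  const res = await fetch(`${MULTIPART_URL}?action=${action}`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(payload),
  })
  if (!res.ok) throw new Error(`${action} failed: ${res.status}`)
  return res.json() as Promise<T>
}

async function uploadInParts(file: File, partSize = 8 * 1024 * 1024) {
  const { uploadId, key } = await post<{ uploadId: string; key: string }>('initiate', {
    fileName: file.name,
    contentType: file.type,
    fileSize: file.size,
  })
  try {
    const partCount = Math.ceil(file.size / partSize)
    const partNumbers = Array.from({ length: partCount }, (_, i) => i + 1)
    const { presignedUrls } = await post<{
      presignedUrls: Array<{ partNumber: number; url: string }>
    }>('get-part-urls', { uploadId, key, partNumbers })
    const parts = await Promise.all(
      presignedUrls.map(async ({ partNumber, url }) => {
        const start = (partNumber - 1) * partSize
        const chunk = file.slice(start, Math.min(start + partSize, file.size))
        const put = await fetch(url, { method: 'PUT', body: chunk })
        if (!put.ok) throw new Error(`Part ${partNumber} failed: ${put.status}`)
        // S3 returns each part's ETag as a response header.
        return { ETag: put.headers.get('ETag') ?? '', PartNumber: partNumber }
      })
    )
    return await post('complete', { uploadId, key, parts })
  } catch (error) {
    await post('abort', { uploadId, key })
    throw error
  }
}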

View File

@@ -0,0 +1,361 @@
import { PutObjectCommand } from '@aws-sdk/client-s3'
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import {
BLOB_CHAT_CONFIG,
BLOB_CONFIG,
BLOB_COPILOT_CONFIG,
BLOB_KB_CONFIG,
S3_CHAT_CONFIG,
S3_CONFIG,
S3_COPILOT_CONFIG,
S3_KB_CONFIG,
} from '@/lib/uploads/setup'
import { validateFileType } from '@/lib/uploads/validation'
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
const logger = createLogger('BatchPresignedUploadAPI')
interface BatchFileRequest {
fileName: string
contentType: string
fileSize: number
}
interface BatchPresignedUrlRequest {
files: BatchFileRequest[]
}
type UploadType = 'general' | 'knowledge-base' | 'chat' | 'copilot'
export async function POST(request: NextRequest) {
try {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
let data: BatchPresignedUrlRequest
try {
data = await request.json()
} catch {
return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 })
}
const { files } = data
if (!files || !Array.isArray(files) || files.length === 0) {
return NextResponse.json(
{ error: 'files array is required and cannot be empty' },
{ status: 400 }
)
}
if (files.length > 100) {
return NextResponse.json(
{ error: 'Cannot process more than 100 files at once' },
{ status: 400 }
)
}
const uploadTypeParam = request.nextUrl.searchParams.get('type')
const uploadType: UploadType =
uploadTypeParam === 'knowledge-base'
? 'knowledge-base'
: uploadTypeParam === 'chat'
? 'chat'
: uploadTypeParam === 'copilot'
? 'copilot'
: 'general'
const MAX_FILE_SIZE = 100 * 1024 * 1024
for (const file of files) {
if (!file.fileName?.trim()) {
return NextResponse.json({ error: 'fileName is required for all files' }, { status: 400 })
}
if (!file.contentType?.trim()) {
return NextResponse.json(
{ error: 'contentType is required for all files' },
{ status: 400 }
)
}
if (!file.fileSize || file.fileSize <= 0) {
return NextResponse.json(
{ error: 'fileSize must be positive for all files' },
{ status: 400 }
)
}
if (file.fileSize > MAX_FILE_SIZE) {
return NextResponse.json(
{ error: `File ${file.fileName} exceeds maximum size of ${MAX_FILE_SIZE} bytes` },
{ status: 400 }
)
}
if (uploadType === 'knowledge-base') {
const fileValidationError = validateFileType(file.fileName, file.contentType)
if (fileValidationError) {
return NextResponse.json(
{
error: fileValidationError.message,
code: fileValidationError.code,
supportedTypes: fileValidationError.supportedTypes,
},
{ status: 400 }
)
}
}
}
const sessionUserId = session.user.id
if (uploadType === 'copilot' && !sessionUserId?.trim()) {
return NextResponse.json(
{ error: 'Authenticated user session is required for copilot uploads' },
{ status: 400 }
)
}
if (!isUsingCloudStorage()) {
return NextResponse.json(
{ error: 'Direct uploads are only available when cloud storage is enabled' },
{ status: 400 }
)
}
const storageProvider = getStorageProvider()
logger.info(
`Generating batch ${uploadType} presigned URLs for ${files.length} files using ${storageProvider}`
)
const startTime = Date.now()
let result
switch (storageProvider) {
case 's3':
result = await handleBatchS3PresignedUrls(files, uploadType, sessionUserId)
break
case 'blob':
result = await handleBatchBlobPresignedUrls(files, uploadType, sessionUserId)
break
default:
return NextResponse.json(
{ error: `Unknown storage provider: ${storageProvider}` },
{ status: 500 }
)
}
const duration = Date.now() - startTime
logger.info(
`Generated ${files.length} presigned URLs in ${duration}ms (avg ${Math.round(duration / files.length)}ms per file)`
)
return NextResponse.json(result)
} catch (error) {
logger.error('Error generating batch presigned URLs:', error)
return createErrorResponse(
error instanceof Error ? error : new Error('Failed to generate batch presigned URLs')
)
}
}
async function handleBatchS3PresignedUrls(
files: BatchFileRequest[],
uploadType: UploadType,
userId?: string
) {
const config =
uploadType === 'knowledge-base'
? S3_KB_CONFIG
: uploadType === 'chat'
? S3_CHAT_CONFIG
: uploadType === 'copilot'
? S3_COPILOT_CONFIG
: S3_CONFIG
if (!config.bucket || !config.region) {
throw new Error(`S3 configuration missing for ${uploadType} uploads`)
}
const { getS3Client, sanitizeFilenameForMetadata } = await import('@/lib/uploads/s3/s3-client')
const s3Client = getS3Client()
let prefix = ''
if (uploadType === 'knowledge-base') {
prefix = 'kb/'
} else if (uploadType === 'chat') {
prefix = 'chat/'
} else if (uploadType === 'copilot') {
prefix = `${userId}/`
}
const baseMetadata: Record<string, string> = {
uploadedAt: new Date().toISOString(),
}
if (uploadType === 'knowledge-base') {
baseMetadata.purpose = 'knowledge-base'
} else if (uploadType === 'chat') {
baseMetadata.purpose = 'chat'
} else if (uploadType === 'copilot') {
baseMetadata.purpose = 'copilot'
baseMetadata.userId = userId || ''
}
const results = await Promise.all(
files.map(async (file) => {
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
const sanitizedOriginalName = sanitizeFilenameForMetadata(file.fileName)
const metadata = {
...baseMetadata,
originalName: sanitizedOriginalName,
}
const command = new PutObjectCommand({
Bucket: config.bucket,
Key: uniqueKey,
ContentType: file.contentType,
Metadata: metadata,
})
const presignedUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
const finalPath =
uploadType === 'chat'
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
return {
fileName: file.fileName,
presignedUrl,
fileInfo: {
path: finalPath,
key: uniqueKey,
name: file.fileName,
size: file.fileSize,
type: file.contentType,
},
}
})
)
return {
files: results,
directUploadSupported: true,
}
}
async function handleBatchBlobPresignedUrls(
files: BatchFileRequest[],
uploadType: UploadType,
userId?: string
) {
const config =
uploadType === 'knowledge-base'
? BLOB_KB_CONFIG
: uploadType === 'chat'
? BLOB_CHAT_CONFIG
: uploadType === 'copilot'
? BLOB_COPILOT_CONFIG
: BLOB_CONFIG
if (
!config.accountName ||
!config.containerName ||
(!config.accountKey && !config.connectionString)
) {
throw new Error(`Azure Blob configuration missing for ${uploadType} uploads`)
}
const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
const { BlobSASPermissions, generateBlobSASQueryParameters, StorageSharedKeyCredential } =
await import('@azure/storage-blob')
const blobServiceClient = getBlobServiceClient()
const containerClient = blobServiceClient.getContainerClient(config.containerName)
let prefix = ''
if (uploadType === 'knowledge-base') {
prefix = 'kb/'
} else if (uploadType === 'chat') {
prefix = 'chat/'
} else if (uploadType === 'copilot') {
prefix = `${userId}/`
}
const baseUploadHeaders: Record<string, string> = {
'x-ms-blob-type': 'BlockBlob',
'x-ms-meta-uploadedat': new Date().toISOString(),
}
if (uploadType === 'knowledge-base') {
baseUploadHeaders['x-ms-meta-purpose'] = 'knowledge-base'
} else if (uploadType === 'chat') {
baseUploadHeaders['x-ms-meta-purpose'] = 'chat'
} else if (uploadType === 'copilot') {
baseUploadHeaders['x-ms-meta-purpose'] = 'copilot'
baseUploadHeaders['x-ms-meta-userid'] = encodeURIComponent(userId || '')
}
const results = await Promise.all(
files.map(async (file) => {
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
const sasOptions = {
containerName: config.containerName,
blobName: uniqueKey,
permissions: BlobSASPermissions.parse('w'),
startsOn: new Date(),
expiresOn: new Date(Date.now() + 3600 * 1000),
}
const sasToken = generateBlobSASQueryParameters(
sasOptions,
new StorageSharedKeyCredential(config.accountName, config.accountKey || '')
).toString()
const presignedUrl = `${blockBlobClient.url}?${sasToken}`
const finalPath =
uploadType === 'chat'
? blockBlobClient.url
: `/api/files/serve/blob/${encodeURIComponent(uniqueKey)}`
const uploadHeaders = {
...baseUploadHeaders,
'x-ms-blob-content-type': file.contentType,
'x-ms-meta-originalname': encodeURIComponent(file.fileName),
}
return {
fileName: file.fileName,
presignedUrl,
fileInfo: {
path: finalPath,
key: uniqueKey,
name: file.fileName,
size: file.fileSize,
type: file.contentType,
},
uploadHeaders,
}
})
)
return {
files: results,
directUploadSupported: true,
}
}
export async function OPTIONS() {
return createOptionsResponse()
}
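
A hedged sketch of the browser-side flow for this batch route; the path is an assumption, and uploadHeaders (returned only for Azure Blob) is spread conditionally so the same code serves both providers.

// Sketch: presign up to 100 files in one call, then PUT each directly to storage.
interface PresignedEntry {
  fileName: string
  presignedUrl: string
  fileInfo: { path: string; key: string; name: string; size: number; type: string }
  uploadHeaders?: Record<string, string>
}

async function uploadBatch(files: File[], type = 'general') {
  const res = await fetch(`/api/files/presigned/batch?type=${type}`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    credentials: 'include',
    body: JSON.stringify({
      files: files.map((f) => ({ fileName: f.name, contentType: f.type, fileSize: f.size })),
    }),
  })
  if (!res.ok) throw new Error(`Batch presign failed: ${res.status}`)
  const { files: presigned }: { files: PresignedEntry[] } = await res.json()
  return Promise.all(
    presigned.map(async (entry, i) => {
      const put = await fetch(entry.presignedUrl, {
        method: 'PUT',
        headers: { 'Content-Type': files[i].type, ...(entry.uploadHeaders ?? {}) },
        body: files[i],
      })
      if (!put.ok) throw new Error(`Upload failed for ${entry.fileName}: ${put.status}`)
      return entry.fileInfo
    })
  )
}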

View File

@@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import { isImageFileType } from '@/lib/uploads/file-utils'
// Dynamic imports for storage clients to avoid client-side bundling
import {
BLOB_CHAT_CONFIG,
@@ -16,6 +17,7 @@ import {
S3_COPILOT_CONFIG,
S3_KB_CONFIG,
} from '@/lib/uploads/setup'
import { validateFileType } from '@/lib/uploads/validation'
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
const logger = createLogger('PresignedUploadAPI')
@@ -96,6 +98,13 @@ export async function POST(request: NextRequest) {
? 'copilot'
: 'general'
if (uploadType === 'knowledge-base') {
const fileValidationError = validateFileType(fileName, contentType)
if (fileValidationError) {
throw new ValidationError(`${fileValidationError.message}`)
}
}
// Evaluate user id from session for copilot uploads
const sessionUserId = session.user.id
@@ -104,6 +113,12 @@ export async function POST(request: NextRequest) {
if (!sessionUserId?.trim()) {
throw new ValidationError('Authenticated user session is required for copilot uploads')
}
// Only allow image uploads for copilot
if (!isImageFileType(contentType)) {
throw new ValidationError(
'Only image files (JPEG, PNG, GIF, WebP, SVG) are allowed for copilot uploads'
)
}
}
if (!isUsingCloudStorage()) {
@@ -224,10 +239,9 @@ async function handleS3PresignedUrl(
)
}
- // For chat images, use direct S3 URLs since they need to be permanently accessible
- // For other files, use serve path for access control
+ // For chat images and knowledge base files, use direct URLs since they need to be accessible by external services
const finalPath =
- uploadType === 'chat'
+ uploadType === 'chat' || uploadType === 'knowledge-base'
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
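
Neither validator's implementation appears in this diff. A plausible shape for isImageFileType, inferred only from the error message in the hunk above (JPEG, PNG, GIF, WebP, SVG), might be:

// Sketch only — the real helper lives in '@/lib/uploads/file-utils'.
const IMAGE_MIME_TYPES = new Set([
  'image/jpeg',
  'image/png',
  'image/gif',
  'image/webp',
  'image/svg+xml',
])

export function isImageFileType(contentType: string): boolean {
  return IMAGE_MIME_TYPES.has(contentType.toLowerCase())
}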

View File

@@ -2,7 +2,7 @@ import { readFile } from 'fs/promises'
import type { NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
import { downloadFile, getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
- import { BLOB_KB_CONFIG, S3_KB_CONFIG } from '@/lib/uploads/setup'
+ import { S3_KB_CONFIG } from '@/lib/uploads/setup'
import '@/lib/uploads/setup.server'
import {
@@ -15,19 +15,6 @@ import {
const logger = createLogger('FilesServeAPI')
- async function streamToBuffer(readableStream: NodeJS.ReadableStream): Promise<Buffer> {
- return new Promise((resolve, reject) => {
- const chunks: Buffer[] = []
- readableStream.on('data', (data) => {
- chunks.push(data instanceof Buffer ? data : Buffer.from(data))
- })
- readableStream.on('end', () => {
- resolve(Buffer.concat(chunks))
- })
- readableStream.on('error', reject)
- })
- }
/**
* Main API route handler for serving files
*/
@@ -102,49 +89,23 @@ async function handleLocalFile(filename: string): Promise<NextResponse> {
}
async function downloadKBFile(cloudKey: string): Promise<Buffer> {
logger.info(`Downloading KB file: ${cloudKey}`)
const storageProvider = getStorageProvider()
if (storageProvider === 'blob') {
logger.info(`Downloading KB file from Azure Blob Storage: ${cloudKey}`)
- // Use KB-specific blob configuration
- const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
- const blobServiceClient = getBlobServiceClient()
- const containerClient = blobServiceClient.getContainerClient(BLOB_KB_CONFIG.containerName)
- const blockBlobClient = containerClient.getBlockBlobClient(cloudKey)
- const downloadBlockBlobResponse = await blockBlobClient.download()
- if (!downloadBlockBlobResponse.readableStreamBody) {
- throw new Error('Failed to get readable stream from blob download')
- }
- // Convert stream to buffer
- return await streamToBuffer(downloadBlockBlobResponse.readableStreamBody)
+ const { BLOB_KB_CONFIG } = await import('@/lib/uploads/setup')
+ return downloadFile(cloudKey, {
+ containerName: BLOB_KB_CONFIG.containerName,
+ accountName: BLOB_KB_CONFIG.accountName,
+ accountKey: BLOB_KB_CONFIG.accountKey,
+ connectionString: BLOB_KB_CONFIG.connectionString,
+ })
}
if (storageProvider === 's3') {
logger.info(`Downloading KB file from S3: ${cloudKey}`)
- // Use KB-specific S3 configuration
- const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
- const { GetObjectCommand } = await import('@aws-sdk/client-s3')
- const s3Client = getS3Client()
- const command = new GetObjectCommand({
- Bucket: S3_KB_CONFIG.bucket,
- Key: cloudKey,
- })
- const response = await s3Client.send(command)
- if (!response.Body) {
- throw new Error('No body in S3 response')
- }
- // Convert stream to buffer using the same method as the regular S3 client
- const stream = response.Body as any
- return new Promise<Buffer>((resolve, reject) => {
- const chunks: Buffer[] = []
- stream.on('data', (chunk: Buffer) => chunks.push(chunk))
- stream.on('end', () => resolve(Buffer.concat(chunks)))
- stream.on('error', reject)
+ return downloadFile(cloudKey, {
+ bucket: S3_KB_CONFIG.bucket,
+ region: S3_KB_CONFIG.region,
+ })
}
@@ -167,17 +128,22 @@ async function handleCloudProxy(
if (isKBFile) {
fileBuffer = await downloadKBFile(cloudKey)
} else if (bucketType === 'copilot') {
// Download from copilot-specific bucket
const storageProvider = getStorageProvider()
if (storageProvider === 's3') {
- const { downloadFromS3WithConfig } = await import('@/lib/uploads/s3/s3-client')
const { S3_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
- fileBuffer = await downloadFromS3WithConfig(cloudKey, S3_COPILOT_CONFIG)
+ fileBuffer = await downloadFile(cloudKey, {
+ bucket: S3_COPILOT_CONFIG.bucket,
+ region: S3_COPILOT_CONFIG.region,
+ })
} else if (storageProvider === 'blob') {
- // For Azure Blob, use the default downloadFile for now
- // TODO: Add downloadFromBlobWithConfig when needed
- fileBuffer = await downloadFile(cloudKey)
+ const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
+ fileBuffer = await downloadFile(cloudKey, {
+ containerName: BLOB_COPILOT_CONFIG.containerName,
+ accountName: BLOB_COPILOT_CONFIG.accountName,
+ accountKey: BLOB_COPILOT_CONFIG.accountKey,
+ connectionString: BLOB_COPILOT_CONFIG.connectionString,
+ })
} else {
fileBuffer = await downloadFile(cloudKey)
}
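
The refactor funnels every provider-specific download through a single downloadFile(key, config?) entry point. Its declaration is not shown in this diff; inferred from the call sites above, it plausibly looks like:

// Inferred shape only — the real declaration lives in '@/lib/uploads'.
interface BlobDownloadConfig {
  containerName: string
  accountName?: string
  accountKey?: string
  connectionString?: string
}

interface S3DownloadConfig {
  bucket: string
  region: string
}

declare function downloadFile(
  key: string,
  config?: BlobDownloadConfig | S3DownloadConfig
): Promise<Buffer>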

View File

@@ -186,3 +186,190 @@ describe('File Upload API Route', () => {
expect(response.headers.get('Access-Control-Allow-Headers')).toBe('Content-Type')
})
})
describe('File Upload Security Tests', () => {
beforeEach(() => {
vi.resetModules()
vi.clearAllMocks()
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue({
user: { id: 'test-user-id' },
}),
}))
vi.doMock('@/lib/uploads', () => ({
isUsingCloudStorage: vi.fn().mockReturnValue(false),
uploadFile: vi.fn().mockResolvedValue({
key: 'test-key',
path: '/test/path',
}),
}))
vi.doMock('@/lib/uploads/setup.server', () => ({}))
})
afterEach(() => {
vi.clearAllMocks()
})
describe('File Extension Validation', () => {
it('should accept allowed file types', async () => {
const allowedTypes = [
'pdf',
'doc',
'docx',
'txt',
'md',
'png',
'jpg',
'jpeg',
'gif',
'csv',
'xlsx',
'xls',
]
for (const ext of allowedTypes) {
const formData = new FormData()
const file = new File(['test content'], `test.${ext}`, { type: 'application/octet-stream' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(200)
}
})
it('should reject HTML files to prevent XSS', async () => {
const formData = new FormData()
const maliciousContent = '<script>alert("XSS")</script>'
const file = new File([maliciousContent], 'malicious.html', { type: 'text/html' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'html' is not allowed")
})
it('should reject SVG files to prevent XSS', async () => {
const formData = new FormData()
const maliciousSvg = '<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
const file = new File([maliciousSvg], 'malicious.svg', { type: 'image/svg+xml' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'svg' is not allowed")
})
it('should reject JavaScript files', async () => {
const formData = new FormData()
const maliciousJs = 'alert("XSS")'
const file = new File([maliciousJs], 'malicious.js', { type: 'application/javascript' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'js' is not allowed")
})
it('should reject files without extensions', async () => {
const formData = new FormData()
const file = new File(['test content'], 'noextension', { type: 'application/octet-stream' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'noextension' is not allowed")
})
it('should handle multiple files with mixed valid/invalid types', async () => {
const formData = new FormData()
// Valid file
const validFile = new File(['valid content'], 'valid.pdf', { type: 'application/pdf' })
formData.append('file', validFile)
// Invalid file (should cause rejection of entire request)
const invalidFile = new File(['<script>alert("XSS")</script>'], 'malicious.html', {
type: 'text/html',
})
formData.append('file', invalidFile)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'html' is not allowed")
})
})
describe('Authentication Requirements', () => {
it('should reject uploads without authentication', async () => {
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue(null),
}))
const formData = new FormData()
const file = new File(['test content'], 'test.pdf', { type: 'application/pdf' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(401)
const data = await response.json()
expect(data.error).toBe('Unauthorized')
})
})
})
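
A note on the mocking pattern used above: vi.doMock is not hoisted, so it only affects modules imported after the call; that is why each test imports the route dynamically and beforeEach starts with vi.resetModules(). The skeleton is:

// Pattern used throughout these tests (vitest).
vi.resetModules() // drop any cached copy of the route module
vi.doMock('@/lib/auth', () => ({
  getSession: vi.fn().mockResolvedValue(null), // e.g. simulate a logged-out user
}))
const { POST } = await import('@/app/api/files/upload/route') // picks up the mock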

View File

@@ -9,6 +9,34 @@ import {
InvalidRequestError,
} from '@/app/api/files/utils'
// Allowlist of permitted file extensions for security
const ALLOWED_EXTENSIONS = new Set([
// Documents
'pdf',
'doc',
'docx',
'txt',
'md',
// Images (safe formats)
'png',
'jpg',
'jpeg',
'gif',
// Data files
'csv',
'xlsx',
'xls',
])
/**
* Validates file extension against allowlist
*/
function validateFileExtension(filename: string): boolean {
const extension = filename.split('.').pop()?.toLowerCase()
if (!extension) return false
return ALLOWED_EXTENSIONS.has(extension)
}
export const dynamic = 'force-dynamic'
const logger = createLogger('FilesUploadAPI')
@@ -49,6 +77,14 @@ export async function POST(request: NextRequest) {
// Process each file
for (const file of files) {
const originalName = file.name
if (!validateFileExtension(originalName)) {
const extension = originalName.split('.').pop()?.toLowerCase() || 'unknown'
throw new InvalidRequestError(
`File type '${extension}' is not allowed. Allowed types: ${Array.from(ALLOWED_EXTENSIONS).join(', ')}`
)
}
const bytes = await file.arrayBuffer()
const buffer = Buffer.from(bytes)
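
Because the check keys off the final dot-separated segment, a few edge cases follow directly from the function above:

validateFileExtension('report.pdf')     // true
validateFileExtension('PHOTO.JPG')      // true — the extension is lowercased before lookup
validateFileExtension('archive.tar.gz') // false — only 'gz' is examined, and it is not allowlisted
validateFileExtension('noextension')    // false — split('.').pop() yields the whole name,
                                        // which is not in the allowlist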

View File

@@ -0,0 +1,327 @@
import { describe, expect, it } from 'vitest'
import { createFileResponse, extractFilename } from './utils'
describe('extractFilename', () => {
describe('legitimate file paths', () => {
it('should extract filename from standard serve path', () => {
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
})
it('should extract filename from serve path with special characters', () => {
expect(extractFilename('/api/files/serve/document-with-dashes_and_underscores.pdf')).toBe(
'document-with-dashes_and_underscores.pdf'
)
})
it('should handle simple filename without serve path', () => {
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
})
it('should extract last segment from nested path', () => {
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
})
})
describe('cloud storage paths', () => {
it('should preserve S3 path structure', () => {
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
's3/1234567890-test-file.txt'
)
})
it('should preserve S3 path with nested folders', () => {
expect(extractFilename('/api/files/serve/s3/folder/subfolder/document.pdf')).toBe(
's3/folder/subfolder/document.pdf'
)
})
it('should preserve Azure Blob path structure', () => {
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
'blob/1234567890-test-document.pdf'
)
})
it('should preserve Blob path with nested folders', () => {
expect(extractFilename('/api/files/serve/blob/uploads/user-files/report.xlsx')).toBe(
'blob/uploads/user-files/report.xlsx'
)
})
})
describe('security - path traversal prevention', () => {
it('should sanitize basic path traversal attempt', () => {
expect(extractFilename('/api/files/serve/../config.txt')).toBe('config.txt')
})
it('should sanitize deep path traversal attempt', () => {
expect(extractFilename('/api/files/serve/../../../../../etc/passwd')).toBe('etcpasswd')
})
it('should sanitize multiple path traversal patterns', () => {
expect(extractFilename('/api/files/serve/../../secret.txt')).toBe('secret.txt')
})
it('should sanitize path traversal with forward slashes', () => {
expect(extractFilename('/api/files/serve/../../../system/file')).toBe('systemfile')
})
it('should sanitize mixed path traversal patterns', () => {
expect(extractFilename('/api/files/serve/../folder/../file.txt')).toBe('folderfile.txt')
})
it('should remove directory separators from local filenames', () => {
expect(extractFilename('/api/files/serve/folder/with/separators.txt')).toBe(
'folderwithseparators.txt'
)
})
it('should handle backslash path separators (Windows style)', () => {
expect(extractFilename('/api/files/serve/folder\\file.txt')).toBe('folderfile.txt')
})
})
describe('cloud storage path traversal prevention', () => {
it('should sanitize S3 path traversal attempts while preserving structure', () => {
expect(extractFilename('/api/files/serve/s3/../config')).toBe('s3/config')
})
it('should sanitize S3 path with nested traversal attempts', () => {
expect(extractFilename('/api/files/serve/s3/folder/../sensitive/../file.txt')).toBe(
's3/folder/sensitive/file.txt'
)
})
it('should sanitize Blob path traversal attempts while preserving structure', () => {
expect(extractFilename('/api/files/serve/blob/../system.txt')).toBe('blob/system.txt')
})
it('should remove leading dots from cloud path segments', () => {
expect(extractFilename('/api/files/serve/s3/.hidden/../file.txt')).toBe('s3/hidden/file.txt')
})
})
describe('edge cases and error handling', () => {
it('should handle filename with dots (but not traversal)', () => {
expect(extractFilename('/api/files/serve/file.with.dots.txt')).toBe('file.with.dots.txt')
})
it('should handle filename with multiple extensions', () => {
expect(extractFilename('/api/files/serve/archive.tar.gz')).toBe('archive.tar.gz')
})
it('should throw error for empty filename after sanitization', () => {
expect(() => extractFilename('/api/files/serve/')).toThrow(
'Invalid or empty filename after sanitization'
)
})
it('should throw error for filename that becomes empty after path traversal removal', () => {
expect(() => extractFilename('/api/files/serve/../..')).toThrow(
'Invalid or empty filename after sanitization'
)
})
it('should handle single character filenames', () => {
expect(extractFilename('/api/files/serve/a')).toBe('a')
})
it('should handle numeric filenames', () => {
expect(extractFilename('/api/files/serve/123')).toBe('123')
})
})
describe('backward compatibility', () => {
it('should match old behavior for legitimate local files', () => {
// These test cases verify that our security fix maintains exact backward compatibility
// for all legitimate use cases found in the existing codebase
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
expect(extractFilename('/api/files/serve/nonexistent.txt')).toBe('nonexistent.txt')
})
it('should match old behavior for legitimate cloud files', () => {
// These test cases are from the actual delete route tests
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
's3/1234567890-test-file.txt'
)
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
'blob/1234567890-test-document.pdf'
)
})
it('should match old behavior for simple paths', () => {
// These match the mock implementations in serve route tests
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
})
})
describe('File Serving Security Tests', () => {
describe('createFileResponse security headers', () => {
it('should serve safe images inline with proper headers', () => {
const response = createFileResponse({
buffer: Buffer.from('fake-image-data'),
contentType: 'image/png',
filename: 'safe-image.png',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('image/png')
expect(response.headers.get('Content-Disposition')).toBe(
'inline; filename="safe-image.png"'
)
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
expect(response.headers.get('Content-Security-Policy')).toBe(
"default-src 'none'; style-src 'unsafe-inline'; sandbox;"
)
})
it('should serve PDFs inline safely', () => {
const response = createFileResponse({
buffer: Buffer.from('fake-pdf-data'),
contentType: 'application/pdf',
filename: 'document.pdf',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/pdf')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.pdf"')
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
it('should force attachment for HTML files to prevent XSS', () => {
const response = createFileResponse({
buffer: Buffer.from('<script>alert("XSS")</script>'),
contentType: 'text/html',
filename: 'malicious.html',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.html"'
)
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
it('should force attachment for SVG files to prevent XSS', () => {
const response = createFileResponse({
buffer: Buffer.from(
'<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
),
contentType: 'image/svg+xml',
filename: 'malicious.svg',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.svg"'
)
})
it('should override dangerous content types to safe alternatives', () => {
const response = createFileResponse({
buffer: Buffer.from('<svg>safe content</svg>'),
contentType: 'image/svg+xml',
filename: 'image.png', // Extension doesn't match content-type
})
expect(response.status).toBe(200)
// Should override SVG content type to plain text for safety
expect(response.headers.get('Content-Type')).toBe('text/plain')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="image.png"')
})
it('should force attachment for JavaScript files', () => {
const response = createFileResponse({
buffer: Buffer.from('alert("XSS")'),
contentType: 'application/javascript',
filename: 'malicious.js',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.js"'
)
})
it('should force attachment for CSS files', () => {
const response = createFileResponse({
buffer: Buffer.from('body { background: url(javascript:alert("XSS")) }'),
contentType: 'text/css',
filename: 'malicious.css',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.css"'
)
})
it('should force attachment for XML files', () => {
const response = createFileResponse({
buffer: Buffer.from('<?xml version="1.0"?><root><script>alert("XSS")</script></root>'),
contentType: 'application/xml',
filename: 'malicious.xml',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.xml"'
)
})
it('should serve text files safely', () => {
const response = createFileResponse({
buffer: Buffer.from('Safe text content'),
contentType: 'text/plain',
filename: 'document.txt',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('text/plain')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.txt"')
})
it('should force attachment for unknown/unsafe content types', () => {
const response = createFileResponse({
buffer: Buffer.from('unknown content'),
contentType: 'application/unknown',
filename: 'unknown.bin',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/unknown')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="unknown.bin"'
)
})
})
describe('Content Security Policy', () => {
it('should include CSP header in all responses', () => {
const response = createFileResponse({
buffer: Buffer.from('test'),
contentType: 'text/plain',
filename: 'test.txt',
})
const csp = response.headers.get('Content-Security-Policy')
expect(csp).toBe("default-src 'none'; style-src 'unsafe-inline'; sandbox;")
})
it('should include X-Content-Type-Options header', () => {
const response = createFileResponse({
buffer: Buffer.from('test'),
contentType: 'text/plain',
filename: 'test.txt',
})
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
})
})
})

View File

@@ -70,7 +70,6 @@ export const contentTypeMap: Record<string, string> = {
jpg: 'image/jpeg',
jpeg: 'image/jpeg',
gif: 'image/gif',
- svg: 'image/svg+xml',
// Archive formats
zip: 'application/zip',
// Folder format
@@ -153,10 +152,43 @@ export function extractBlobKey(path: string): string {
* Extract filename from a serve path
*/
export function extractFilename(path: string): string {
+ let filename: string
if (path.startsWith('/api/files/serve/')) {
- return path.substring('/api/files/serve/'.length)
+ filename = path.substring('/api/files/serve/'.length)
+ } else {
+ filename = path.split('/').pop() || path
+ }
- return path.split('/').pop() || path
filename = filename
.replace(/\.\./g, '')
.replace(/\/\.\./g, '')
.replace(/\.\.\//g, '')
// Handle cloud storage paths (s3/key, blob/key) - preserve forward slashes for these
if (filename.startsWith('s3/') || filename.startsWith('blob/')) {
// For cloud paths, only sanitize the key portion after the prefix
const parts = filename.split('/')
const prefix = parts[0] // 's3' or 'blob'
const keyParts = parts.slice(1)
// Sanitize each part of the key to prevent traversal
const sanitizedKeyParts = keyParts
.map((part) => part.replace(/\.\./g, '').replace(/^\./g, '').trim())
.filter((part) => part.length > 0)
filename = `${prefix}/${sanitizedKeyParts.join('/')}`
} else {
// For regular filenames, remove any remaining path separators
filename = filename.replace(/[/\\]/g, '')
}
// Additional validation: ensure filename is not empty after sanitization
if (!filename || filename.trim().length === 0) {
throw new Error('Invalid or empty filename after sanitization')
}
return filename
}
/**
@@ -174,16 +206,65 @@ export function findLocalFile(filename: string): string | null {
return null
}
const SAFE_INLINE_TYPES = new Set([
'image/png',
'image/jpeg',
'image/jpg',
'image/gif',
'application/pdf',
'text/plain',
'text/csv',
'application/json',
])
// File extensions that should always be served as attachment for security
const FORCE_ATTACHMENT_EXTENSIONS = new Set(['html', 'htm', 'svg', 'js', 'css', 'xml'])
/**
- * Create a file response with appropriate headers
+ * Determines safe content type and disposition for file serving
*/
function getSecureFileHeaders(filename: string, originalContentType: string) {
const extension = filename.split('.').pop()?.toLowerCase() || ''
// Force attachment for potentially dangerous file types
if (FORCE_ATTACHMENT_EXTENSIONS.has(extension)) {
return {
contentType: 'application/octet-stream', // Force download
disposition: 'attachment',
}
}
// Override content type for safety while preserving legitimate use cases
let safeContentType = originalContentType
// Handle potentially dangerous content types
if (originalContentType === 'text/html' || originalContentType === 'image/svg+xml') {
safeContentType = 'text/plain' // Prevent browser rendering
}
// Use inline only for verified safe content types
const disposition = SAFE_INLINE_TYPES.has(safeContentType) ? 'inline' : 'attachment'
return {
contentType: safeContentType,
disposition,
}
}
/**
* Create a file response with appropriate security headers
*/
export function createFileResponse(file: FileResponse): NextResponse {
const { contentType, disposition } = getSecureFileHeaders(file.filename, file.contentType)
return new NextResponse(file.buffer as BodyInit, {
status: 200,
headers: {
- 'Content-Type': file.contentType,
- 'Content-Disposition': `inline; filename="${file.filename}"`,
+ 'Content-Type': contentType,
+ 'Content-Disposition': `${disposition}; filename="${file.filename}"`,
'Cache-Control': 'public, max-age=31536000', // Cache for 1 year
'X-Content-Type-Options': 'nosniff',
'Content-Security-Policy': "default-src 'none'; style-src 'unsafe-inline'; sandbox;",
},
})
}
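
One subtlety in getSecureFileHeaders worth tracing (this mirrors the "override dangerous content types" test in the spec file above):

// Worked trace: getSecureFileHeaders('image.png', 'image/svg+xml')
//   extension 'png'  -> not in FORCE_ATTACHMENT_EXTENSIONS, so no forced download
//   'image/svg+xml'  -> rewritten to 'text/plain' to block browser rendering
//   'text/plain'     -> in SAFE_INLINE_TYPES, so disposition stays 'inline'
// Result: { contentType: 'text/plain', disposition: 'inline' }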

View File

@@ -213,24 +213,81 @@ function createUserFriendlyErrorMessage(
}
/**
- * Resolves environment variables and tags in code
- * @param code - Code with variables
- * @param params - Parameters that may contain variable values
- * @param envVars - Environment variables from the workflow
- * @returns Resolved code
+ * Resolves workflow variables with <variable.name> syntax
*/
function resolveWorkflowVariables(
code: string,
workflowVariables: Record<string, any>,
contextVariables: Record<string, any>
): string {
let resolvedCode = code
- function resolveCodeVariables(
const variableMatches = resolvedCode.match(/<variable\.([^>]+)>/g) || []
for (const match of variableMatches) {
const variableName = match.slice('<variable.'.length, -1).trim()
// Find the variable by name (workflowVariables is indexed by ID, values are variable objects)
const foundVariable = Object.entries(workflowVariables).find(
([_, variable]) => (variable.name || '').replace(/\s+/g, '') === variableName
)
if (foundVariable) {
const variable = foundVariable[1]
// Get the typed value - handle different variable types
let variableValue = variable.value
if (variable.value !== undefined && variable.value !== null) {
try {
// Handle 'string' type the same as 'plain' for backward compatibility
const type = variable.type === 'string' ? 'plain' : variable.type
// For plain text, use exactly what's entered without modifications
if (type === 'plain' && typeof variableValue === 'string') {
// Use as-is for plain text
} else if (type === 'number') {
variableValue = Number(variableValue)
} else if (type === 'boolean') {
variableValue = variableValue === 'true' || variableValue === true
} else if (type === 'json') {
try {
variableValue =
typeof variableValue === 'string' ? JSON.parse(variableValue) : variableValue
} catch {
// Keep original value if JSON parsing fails
}
}
} catch (error) {
// Fallback to original value on error
variableValue = variable.value
}
}
// Create a safe variable reference
const safeVarName = `__variable_${variableName.replace(/[^a-zA-Z0-9_]/g, '_')}`
contextVariables[safeVarName] = variableValue
// Replace the variable reference with the safe variable name
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
} else {
// Variable not found - replace with empty string to avoid syntax errors
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), '')
}
}
return resolvedCode
}
/**
* Resolves environment variables with {{var_name}} syntax
*/
function resolveEnvironmentVariables(
code: string,
params: Record<string, any>,
- envVars: Record<string, string> = {},
- blockData: Record<string, any> = {},
- blockNameMapping: Record<string, string> = {}
- ): { resolvedCode: string; contextVariables: Record<string, any> } {
+ envVars: Record<string, string>,
+ contextVariables: Record<string, any>
+ ): string {
let resolvedCode = code
- const contextVariables: Record<string, any> = {}
// Resolve environment variables with {{var_name}} syntax
const envVarMatches = resolvedCode.match(/\{\{([^}]+)\}\}/g) || []
for (const match of envVarMatches) {
const varName = match.slice(2, -2).trim()
@@ -245,7 +302,21 @@ function resolveCodeVariables(
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
}
- // Resolve tags with <tag_name> syntax (including nested paths like <block.response.data>)
return resolvedCode
}
/**
* Resolves tags with <tag_name> syntax (including nested paths like <block.response.data>)
*/
function resolveTagVariables(
code: string,
params: Record<string, any>,
blockData: Record<string, any>,
blockNameMapping: Record<string, string>,
contextVariables: Record<string, any>
): string {
let resolvedCode = code
const tagMatches = resolvedCode.match(/<([a-zA-Z_][a-zA-Z0-9_.]*[a-zA-Z0-9_])>/g) || []
for (const match of tagMatches) {
@@ -300,6 +371,42 @@ function resolveCodeVariables(
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
}
return resolvedCode
}
/**
* Resolves environment variables and tags in code
* @param code - Code with variables
* @param params - Parameters that may contain variable values
* @param envVars - Environment variables from the workflow
* @returns Resolved code
*/
function resolveCodeVariables(
code: string,
params: Record<string, any>,
envVars: Record<string, string> = {},
blockData: Record<string, any> = {},
blockNameMapping: Record<string, string> = {},
workflowVariables: Record<string, any> = {}
): { resolvedCode: string; contextVariables: Record<string, any> } {
let resolvedCode = code
const contextVariables: Record<string, any> = {}
// Resolve workflow variables with <variable.name> syntax first
resolvedCode = resolveWorkflowVariables(resolvedCode, workflowVariables, contextVariables)
// Resolve environment variables with {{var_name}} syntax
resolvedCode = resolveEnvironmentVariables(resolvedCode, params, envVars, contextVariables)
// Resolve tags with <tag_name> syntax (including nested paths like <block.response.data>)
resolvedCode = resolveTagVariables(
resolvedCode,
params,
blockData,
blockNameMapping,
contextVariables
)
return { resolvedCode, contextVariables }
}
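
Taken together, the three passes rewrite user code before it reaches the VM. A small worked example for the workflow-variable pass, with the variable shape taken from the lookup logic above:

// Given workflowVariables = { 'var-1': { name: 'user Name', type: 'plain', value: 'Ada' } }
// and the input code:
//   const greeting = `Hello, ${<variable.userName>}`
// the pass matches 'userName' against 'user Name' with whitespace stripped, sets
//   contextVariables.__variable_userName = 'Ada'
// and rewrites the reference to that safe identifier:
//   const greeting = `Hello, ${__variable_userName}`
// A reference that matches no variable, e.g. <variable.missing>, is replaced with ''.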
@@ -338,6 +445,7 @@ export async function POST(req: NextRequest) {
envVars = {},
blockData = {},
blockNameMapping = {},
workflowVariables = {},
workflowId,
isCustomTool = false,
} = body
@@ -360,7 +468,8 @@ export async function POST(req: NextRequest) {
executionParams,
envVars,
blockData,
- blockNameMapping
+ blockNameMapping,
+ workflowVariables
)
resolvedCode = codeResolution.resolvedCode
const contextVariables = codeResolution.contextVariables
@@ -368,8 +477,8 @@ export async function POST(req: NextRequest) {
const executionMethod = 'vm' // Default execution method
logger.info(`[${requestId}] Using VM for code execution`, {
resolvedCode,
- hasEnvVars: Object.keys(envVars).length > 0,
+ hasWorkflowVariables: Object.keys(workflowVariables).length > 0,
})
// Create a secure context with console logging

View File

@@ -1,12 +1,10 @@
- import { createHash, randomUUID } from 'crypto'
- import { eq, sql } from 'drizzle-orm'
+ import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { deleteChunk, updateChunk } from '@/lib/knowledge/chunks/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkChunkAccess } from '@/app/api/knowledge/utils'
- import { db } from '@/db'
- import { document, embedding } from '@/db/schema'
const logger = createLogger('ChunkByIdAPI')
@@ -102,33 +100,7 @@ export async function PUT(
try {
const validatedData = UpdateChunkSchema.parse(body)
- const updateData: Partial<{
- content: string
- contentLength: number
- tokenCount: number
- chunkHash: string
- enabled: boolean
- updatedAt: Date
- }> = {}
- if (validatedData.content) {
- updateData.content = validatedData.content
- updateData.contentLength = validatedData.content.length
- // Update token count estimation (rough approximation: 4 chars per token)
- updateData.tokenCount = Math.ceil(validatedData.content.length / 4)
- updateData.chunkHash = createHash('sha256').update(validatedData.content).digest('hex')
- }
- if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
- await db.update(embedding).set(updateData).where(eq(embedding.id, chunkId))
- // Fetch the updated chunk
- const updatedChunk = await db
- .select()
- .from(embedding)
- .where(eq(embedding.id, chunkId))
- .limit(1)
+ const updatedChunk = await updateChunk(chunkId, validatedData, requestId)
logger.info(
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`
@@ -136,7 +108,7 @@ export async function PUT(
return NextResponse.json({
success: true,
- data: updatedChunk[0],
+ data: updatedChunk,
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
@@ -190,37 +162,7 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
- // Use transaction to atomically delete chunk and update document statistics
- await db.transaction(async (tx) => {
- // Get chunk data before deletion for statistics update
- const chunkToDelete = await tx
- .select({
- tokenCount: embedding.tokenCount,
- contentLength: embedding.contentLength,
- })
- .from(embedding)
- .where(eq(embedding.id, chunkId))
- .limit(1)
- if (chunkToDelete.length === 0) {
- throw new Error('Chunk not found')
- }
- const chunk = chunkToDelete[0]
- // Delete the chunk
- await tx.delete(embedding).where(eq(embedding.id, chunkId))
- // Update document statistics
- await tx
- .update(document)
- .set({
- chunkCount: sql`${document.chunkCount} - 1`,
- tokenCount: sql`${document.tokenCount} - ${chunk.tokenCount}`,
- characterCount: sql`${document.characterCount} - ${chunk.contentLength}`,
- })
- .where(eq(document.id, documentId))
- })
+ await deleteChunk(chunkId, documentId, requestId)
logger.info(
`[${requestId}] Chunk deleted: ${chunkId} from document ${documentId} in knowledge base ${knowledgeBaseId}`

View File

@@ -1,378 +0,0 @@
/**
* Tests for knowledge document chunks API route
*
* @vitest-environment node
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockRequest,
mockAuth,
mockConsoleLogger,
mockDrizzleOrm,
mockKnowledgeSchemas,
} from '@/app/api/__test-utils__/utils'
mockKnowledgeSchemas()
mockDrizzleOrm()
mockConsoleLogger()
vi.mock('@/lib/tokenization/estimators', () => ({
estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }),
}))
vi.mock('@/providers/utils', () => ({
calculateCost: vi.fn().mockReturnValue({
input: 0.00000904,
output: 0,
total: 0.00000904,
pricing: {
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
},
}),
}))
vi.mock('@/app/api/knowledge/utils', () => ({
checkKnowledgeBaseAccess: vi.fn(),
checkKnowledgeBaseWriteAccess: vi.fn(),
checkDocumentAccess: vi.fn(),
checkDocumentWriteAccess: vi.fn(),
checkChunkAccess: vi.fn(),
generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]),
processDocumentAsync: vi.fn(),
}))
describe('Knowledge Document Chunks API Route', () => {
const mockAuth$ = mockAuth()
const mockDbChain = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnThis(),
offset: vi.fn().mockReturnThis(),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
returning: vi.fn().mockResolvedValue([]),
delete: vi.fn().mockReturnThis(),
transaction: vi.fn(),
}
const mockGetUserId = vi.fn()
beforeEach(async () => {
vi.clearAllMocks()
vi.doMock('@/db', () => ({
db: mockDbChain,
}))
vi.doMock('@/app/api/auth/oauth/utils', () => ({
getUserId: mockGetUserId,
}))
Object.values(mockDbChain).forEach((fn) => {
if (typeof fn === 'function' && fn !== mockDbChain.values && fn !== mockDbChain.returning) {
fn.mockClear().mockReturnThis()
}
})
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'),
createHash: vi.fn().mockReturnValue({
update: vi.fn().mockReturnThis(),
digest: vi.fn().mockReturnValue('mock-hash-123'),
}),
})
})
afterEach(() => {
vi.clearAllMocks()
})
describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => {
const validChunkData = {
content: 'This is test chunk content for uploading to the knowledge base document.',
enabled: true,
}
const mockDocumentAccess = {
hasAccess: true,
notFound: false,
reason: '',
document: {
id: 'doc-123',
processingStatus: 'completed',
tag1: 'tag1-value',
tag2: 'tag2-value',
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
},
}
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
it('should create chunk successfully with cost tracking', async () => {
const { checkDocumentWriteAccess, generateEmbeddings } = await import(
'@/app/api/knowledge/utils'
)
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
const { calculateCost } = await import('@/providers/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
// Mock generateEmbeddings
vi.mocked(generateEmbeddings).mockResolvedValue([[0.1, 0.2, 0.3]])
// Mock transaction
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(200)
expect(data.success).toBe(true)
// Verify cost tracking
expect(data.data.cost).toBeDefined()
expect(data.data.cost.input).toBe(0.00000904)
expect(data.data.cost.output).toBe(0)
expect(data.data.cost.total).toBe(0.00000904)
expect(data.data.cost.tokens).toEqual({
prompt: 452,
completion: 0,
total: 452,
})
expect(data.data.cost.model).toBe('text-embedding-3-small')
expect(data.data.cost.pricing).toEqual({
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
})
// Verify function calls
expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai')
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false)
})
it('should handle workflow-based authentication', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const workflowData = {
...validChunkData,
workflowId: 'workflow-123',
}
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', workflowData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
})
it.concurrent('should return unauthorized for unauthenticated request', async () => {
mockGetUserId.mockResolvedValue(null)
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(401)
expect(data.error).toBe('Unauthorized')
})
it('should return not found for workflow that does not exist', async () => {
const workflowData = {
...validChunkData,
workflowId: 'nonexistent-workflow',
}
mockGetUserId.mockResolvedValue(null)
const req = createMockRequest('POST', workflowData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Workflow not found')
})
it.concurrent('should return not found for document access denied', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
hasAccess: false,
notFound: true,
reason: 'Document not found',
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Document not found')
})
it('should return unauthorized for unauthorized document access', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
hasAccess: false,
notFound: false,
reason: 'Unauthorized access',
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(401)
expect(data.error).toBe('Unauthorized')
})
it('should reject chunks for failed documents', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
document: {
...mockDocumentAccess.document!,
processingStatus: 'failed',
},
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(400)
expect(data.error).toBe('Cannot add chunks to failed document')
})
it.concurrent('should validate chunk data', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const invalidData = {
content: '', // Empty content
enabled: true,
}
const req = createMockRequest('POST', invalidData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(400)
expect(data.error).toBe('Invalid request data')
expect(data.details).toBeDefined()
})
it('should inherit tags from parent document', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockImplementation((data) => {
// Verify that tags are inherited from document
expect(data.tag1).toBe('tag1-value')
expect(data.tag2).toBe('tag2-value')
expect(data.tag3).toBe(null)
return Promise.resolve(undefined)
}),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
await POST(req, { params: mockParams })
expect(mockTx.values).toHaveBeenCalled()
})
// REMOVED: "should handle cost calculation with different content lengths" test - it was failing
})
})

View File

@@ -1,18 +1,11 @@
- import crypto from 'crypto'
- import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
+ import { batchChunkOperation, createChunk, queryChunks } from '@/lib/knowledge/chunks/service'
import { createLogger } from '@/lib/logs/console/logger'
- import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
- import {
- checkDocumentAccess,
- checkDocumentWriteAccess,
- generateEmbeddings,
- } from '@/app/api/knowledge/utils'
- import { db } from '@/db'
- import { document, embedding } from '@/db/schema'
+ import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
import { calculateCost } from '@/providers/utils'
const logger = createLogger('DocumentChunksAPI')
@@ -66,7 +59,6 @@ export async function GET(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
- // Check if document processing is completed
const doc = accessCheck.document
if (!doc) {
logger.warn(
@@ -89,7 +81,6 @@ export async function GET(
)
}
- // Parse query parameters
const { searchParams } = new URL(req.url)
const queryParams = GetChunksQuerySchema.parse({
search: searchParams.get('search') || undefined,
@@ -98,67 +89,12 @@ export async function GET(
offset: searchParams.get('offset') || undefined,
})
- // Build query conditions
- const conditions = [eq(embedding.documentId, documentId)]
- // Add enabled filter
- if (queryParams.enabled === 'true') {
- conditions.push(eq(embedding.enabled, true))
- } else if (queryParams.enabled === 'false') {
- conditions.push(eq(embedding.enabled, false))
- }
- // Add search filter
- if (queryParams.search) {
- conditions.push(ilike(embedding.content, `%${queryParams.search}%`))
- }
- // Fetch chunks
- const chunks = await db
- .select({
- id: embedding.id,
- chunkIndex: embedding.chunkIndex,
- content: embedding.content,
- contentLength: embedding.contentLength,
- tokenCount: embedding.tokenCount,
- enabled: embedding.enabled,
- startOffset: embedding.startOffset,
- endOffset: embedding.endOffset,
- tag1: embedding.tag1,
- tag2: embedding.tag2,
- tag3: embedding.tag3,
- tag4: embedding.tag4,
- tag5: embedding.tag5,
- tag6: embedding.tag6,
- tag7: embedding.tag7,
- createdAt: embedding.createdAt,
- updatedAt: embedding.updatedAt,
- })
- .from(embedding)
- .where(and(...conditions))
- .orderBy(asc(embedding.chunkIndex))
- .limit(queryParams.limit)
- .offset(queryParams.offset)
- // Get total count for pagination
- const totalCount = await db
- .select({ count: sql`count(*)` })
- .from(embedding)
- .where(and(...conditions))
- logger.info(
- `[${requestId}] Retrieved ${chunks.length} chunks for document ${documentId} in knowledge base ${knowledgeBaseId}`
- )
+ const result = await queryChunks(documentId, queryParams, requestId)
return NextResponse.json({
success: true,
- data: chunks,
- pagination: {
- total: Number(totalCount[0]?.count || 0),
- limit: queryParams.limit,
- offset: queryParams.offset,
- hasMore: chunks.length === queryParams.limit,
- },
+ data: result.chunks,
+ pagination: result.pagination,
})
} catch (error) {
logger.error(`[${requestId}] Error fetching chunks`, error)
@@ -219,76 +155,27 @@ export async function POST(
try {
const validatedData = CreateChunkSchema.parse(searchParams)
- // Generate embedding for the content first (outside transaction for performance)
- logger.info(`[${requestId}] Generating embedding for manual chunk`)
- const embeddings = await generateEmbeddings([validatedData.content])
const docTags = {
tag1: doc.tag1 ?? null,
tag2: doc.tag2 ?? null,
tag3: doc.tag3 ?? null,
tag4: doc.tag4 ?? null,
tag5: doc.tag5 ?? null,
tag6: doc.tag6 ?? null,
tag7: doc.tag7 ?? null,
}
- // Calculate accurate token count for both database storage and cost calculation
- const tokenCount = estimateTokenCount(validatedData.content, 'openai')
const newChunk = await createChunk(
knowledgeBaseId,
documentId,
docTags,
validatedData,
requestId
)
- const chunkId = crypto.randomUUID()
- const now = new Date()
- // Use transaction to atomically get next index and insert chunk
- const newChunk = await db.transaction(async (tx) => {
- // Get the next chunk index atomically within the transaction
- const lastChunk = await tx
- .select({ chunkIndex: embedding.chunkIndex })
- .from(embedding)
- .where(eq(embedding.documentId, documentId))
- .orderBy(sql`${embedding.chunkIndex} DESC`)
- .limit(1)
- const nextChunkIndex = lastChunk.length > 0 ? lastChunk[0].chunkIndex + 1 : 0
- const chunkData = {
- id: chunkId,
- knowledgeBaseId,
- documentId,
- chunkIndex: nextChunkIndex,
- chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
- content: validatedData.content,
- contentLength: validatedData.content.length,
- tokenCount: tokenCount.count, // Use accurate token count
- embedding: embeddings[0],
- embeddingModel: 'text-embedding-3-small',
- startOffset: 0, // Manual chunks don't have document offsets
- endOffset: validatedData.content.length,
- // Inherit tags from parent document
- tag1: doc.tag1,
- tag2: doc.tag2,
- tag3: doc.tag3,
- tag4: doc.tag4,
- tag5: doc.tag5,
- tag6: doc.tag6,
- tag7: doc.tag7,
- enabled: validatedData.enabled,
- createdAt: now,
- updatedAt: now,
- }
- // Insert the new chunk
- await tx.insert(embedding).values(chunkData)
- // Update document statistics
- await tx
- .update(document)
- .set({
- chunkCount: sql`${document.chunkCount} + 1`,
- tokenCount: sql`${document.tokenCount} + ${chunkData.tokenCount}`,
- characterCount: sql`${document.characterCount} + ${chunkData.contentLength}`,
- })
- .where(eq(document.id, documentId))
- return chunkData
- })
- logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
// Calculate cost for the embedding (with fallback if calculation fails)
let cost = null
try {
- cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
+ cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false)
} catch (error) {
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
error: error instanceof Error ? error.message : 'Unknown error',
@@ -300,6 +187,8 @@ export async function POST(
success: true,
data: {
...newChunk,
documentId,
documentName: doc.filename,
...(cost
? {
cost: {
@@ -307,9 +196,9 @@ export async function POST(
output: cost.output,
total: cost.total,
tokens: {
- prompt: tokenCount.count,
+ prompt: newChunk.tokenCount,
completion: 0,
- total: tokenCount.count,
+ total: newChunk.tokenCount,
},
model: 'text-embedding-3-small',
pricing: cost.pricing,
@@ -371,92 +260,16 @@ export async function PATCH(
const validatedData = BatchOperationSchema.parse(body)
const { operation, chunkIds } = validatedData
logger.info(
`[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}`
)
- const results = []
- let successCount = 0
- const errorCount = 0
- if (operation === 'delete') {
- // Handle batch delete with transaction for consistency
- await db.transaction(async (tx) => {
- // Get chunks to delete for statistics update
- const chunksToDelete = await tx
- .select({
- id: embedding.id,
- tokenCount: embedding.tokenCount,
- contentLength: embedding.contentLength,
- })
- .from(embedding)
- .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
- if (chunksToDelete.length === 0) {
- throw new Error('No valid chunks found to delete')
- }
- // Delete chunks
- await tx
- .delete(embedding)
- .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
- // Update document statistics
- const totalTokens = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0)
- const totalCharacters = chunksToDelete.reduce(
- (sum, chunk) => sum + chunk.contentLength,
- 0
- )
- await tx
- .update(document)
- .set({
- chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`,
- tokenCount: sql`${document.tokenCount} - ${totalTokens}`,
- characterCount: sql`${document.characterCount} - ${totalCharacters}`,
- })
- .where(eq(document.id, documentId))
- successCount = chunksToDelete.length
- results.push({
- operation: 'delete',
- deletedCount: chunksToDelete.length,
- chunkIds: chunksToDelete.map((c) => c.id),
- })
- })
- } else {
- // Handle batch enable/disable
- const enabled = operation === 'enable'
- // Update chunks in a single query
- const updateResult = await db
- .update(embedding)
- .set({
- enabled,
- updatedAt: new Date(),
- })
- .where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
- .returning({ id: embedding.id })
- successCount = updateResult.length
- results.push({
- operation,
- updatedCount: updateResult.length,
- chunkIds: updateResult.map((r) => r.id),
- })
- }
- logger.info(
- `[${requestId}] Batch ${operation} operation completed: ${successCount} successful, ${errorCount} errors`
- )
+ const result = await batchChunkOperation(documentId, operation, chunkIds, requestId)
return NextResponse.json({
success: true,
data: {
operation,
- successCount,
- errorCount,
- results,
+ successCount: result.processed,
+ errorCount: result.errors.length,
+ processed: result.processed,
+ errors: result.errors,
},
})
} catch (validationError) {
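
The route now delegates to a chunk service whose declarations are not part of this diff; inferred from the call sites here and in the chunk-by-id route earlier, the surface plausibly looks like:

// Inferred shapes only — the real module is '@/lib/knowledge/chunks/service'.
declare function queryChunks(
  documentId: string,
  params: { search?: string; enabled?: 'true' | 'false'; limit: number; offset: number },
  requestId: string
): Promise<{
  chunks: unknown[]
  pagination: { total: number; limit: number; offset: number; hasMore: boolean }
}>

declare function createChunk(
  knowledgeBaseId: string,
  documentId: string,
  docTags: Record<string, string | null>,
  data: { content: string; enabled: boolean },
  requestId: string
): Promise<{ tokenCount: number; [key: string]: unknown }>

declare function batchChunkOperation(
  documentId: string,
  operation: 'enable' | 'disable' | 'delete',
  chunkIds: string[],
  requestId: string
): Promise<{ processed: number; errors: unknown[] }>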

View File

@@ -24,7 +24,14 @@ vi.mock('@/app/api/knowledge/utils', () => ({
processDocumentAsync: vi.fn(),
}))
// Setup common mocks
vi.mock('@/lib/knowledge/documents/service', () => ({
updateDocument: vi.fn(),
deleteDocument: vi.fn(),
markDocumentAsFailedTimeout: vi.fn(),
retryDocumentProcessing: vi.fn(),
processDocumentAsync: vi.fn(),
}))
mockDrizzleOrm()
mockConsoleLogger()
@@ -42,8 +49,6 @@ describe('Document By ID API Route', () => {
transaction: vi.fn(),
}
- // Mock functions will be imported dynamically in tests
const mockDocument = {
id: 'doc-123',
knowledgeBaseId: 'kb-123',
@@ -73,7 +78,6 @@ describe('Document By ID API Route', () => {
}
}
})
- // Mock functions are cleared automatically by vitest
}
beforeEach(async () => {
@@ -83,8 +87,6 @@ describe('Document By ID API Route', () => {
db: mockDbChain,
}))
- // Utils are mocked at the top level
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
})
@@ -195,6 +197,7 @@ describe('Document By ID API Route', () => {
it('should update document successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { updateDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -203,31 +206,12 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
- // Create a sequence of mocks for the database operations
- const updateChain = {
- set: vi.fn().mockReturnValue({
- where: vi.fn().mockResolvedValue(undefined), // Update operation completes
- }),
+ const updatedDocument = {
+ ...mockDocument,
+ ...validUpdateData,
+ deletedAt: null,
+ }
- const selectChain = {
- from: vi.fn().mockReturnValue({
- where: vi.fn().mockReturnValue({
- limit: vi.fn().mockResolvedValue([{ ...mockDocument, ...validUpdateData }]),
- }),
- }),
- }
- // Mock transaction
- mockDbChain.transaction.mockImplementation(async (callback) => {
- const mockTx = {
- update: vi.fn().mockReturnValue(updateChain),
- }
- await callback(mockTx)
- })
- // Mock db operations in sequence
- mockDbChain.select.mockReturnValue(selectChain)
+ vi.mocked(updateDocument).mockResolvedValue(updatedDocument)
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -238,8 +222,11 @@ describe('Document By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.filename).toBe('updated-document.pdf')
expect(data.data.enabled).toBe(false)
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(mockDbChain.select).toHaveBeenCalled()
expect(vi.mocked(updateDocument)).toHaveBeenCalledWith(
'doc-123',
validUpdateData,
expect.any(String)
)
})
it('should validate update data', async () => {
@@ -274,6 +261,7 @@ describe('Document By ID API Route', () => {
it('should mark document as failed due to timeout successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
const processingDocument = {
...mockDocument,
@@ -288,34 +276,11 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Create a sequence of mocks for the database operations
const updateChain = {
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined), // Update operation completes
}),
}
const selectChain = {
from: vi.fn().mockReturnValue({
where: vi.fn().mockReturnValue({
limit: vi
.fn()
.mockResolvedValue([{ ...processingDocument, processingStatus: 'failed' }]),
}),
}),
}
// Mock transaction
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
update: vi.fn().mockReturnValue(updateChain),
}
await callback(mockTx)
})
vi.mocked(markDocumentAsFailedTimeout).mockResolvedValue({
success: true,
processingDuration: 200000,
})
// Mock db operations in sequence
mockDbChain.select.mockReturnValue(selectChain)
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -323,13 +288,13 @@ describe('Document By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(updateChain.set).toHaveBeenCalledWith(
expect.objectContaining({
processingStatus: 'failed',
processingError: 'Processing timed out - background process may have been terminated',
processingCompletedAt: expect.any(Date),
})
)
expect(data.data.documentId).toBe('doc-123')
expect(data.data.status).toBe('failed')
expect(data.data.message).toBe('Document marked as failed due to timeout')
expect(vi.mocked(markDocumentAsFailedTimeout)).toHaveBeenCalledWith(
'doc-123',
processingDocument.processingStartedAt,
expect.any(String)
)
})
@@ -354,6 +319,7 @@ describe('Document By ID API Route', () => {
it('should reject marking failed for recently started processing', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
const recentProcessingDocument = {
...mockDocument,
@@ -368,6 +334,10 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
vi.mocked(markDocumentAsFailedTimeout).mockRejectedValue(
new Error('Document has not been processing long enough to be considered dead')
)
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -382,9 +352,8 @@ describe('Document By ID API Route', () => {
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
it('should retry processing successfully', async () => {
const { checkDocumentWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { retryDocumentProcessing } = await import('@/lib/knowledge/documents/service')
const failedDocument = {
...mockDocument,
@@ -399,23 +368,12 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
delete: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined),
}),
update: vi.fn().mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined),
}),
}),
}
return await callback(mockTx)
})
vi.mocked(retryDocumentProcessing).mockResolvedValue({
success: true,
status: 'pending',
message: 'Document retry processing started',
})
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
const req = createMockRequest('PUT', { retryProcessing: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -425,8 +383,17 @@ describe('Document By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.status).toBe('pending')
expect(data.data.message).toBe('Document retry processing started')
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(vi.mocked(processDocumentAsync)).toHaveBeenCalled()
expect(vi.mocked(retryDocumentProcessing)).toHaveBeenCalledWith(
'kb-123',
'doc-123',
{
filename: failedDocument.filename,
fileUrl: failedDocument.fileUrl,
fileSize: failedDocument.fileSize,
mimeType: failedDocument.mimeType,
},
expect.any(String)
)
})
it('should reject retry for non-failed document', async () => {
@@ -486,6 +453,7 @@ describe('Document By ID API Route', () => {
it('should handle database errors during update', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { updateDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -494,8 +462,7 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to throw an error
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))
vi.mocked(updateDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -512,6 +479,7 @@ describe('Document By ID API Route', () => {
it('should delete document successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -520,10 +488,10 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Properly chain the mock database operations for soft delete
mockDbChain.update.mockReturnValue(mockDbChain)
mockDbChain.set.mockReturnValue(mockDbChain)
mockDbChain.where.mockResolvedValue(undefined) // Update operation resolves
vi.mocked(deleteDocument).mockResolvedValue({
success: true,
message: 'Document deleted successfully',
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -533,12 +501,7 @@ describe('Document By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.message).toBe('Document deleted successfully')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockDbChain.set).toHaveBeenCalledWith(
expect.objectContaining({
deletedAt: expect.any(Date),
})
)
expect(vi.mocked(deleteDocument)).toHaveBeenCalledWith('doc-123', expect.any(String))
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -592,6 +555,7 @@ describe('Document By ID API Route', () => {
it('should handle database errors during deletion', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -599,7 +563,7 @@ describe('Document By ID API Route', () => {
document: mockDocument,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.set.mockRejectedValue(new Error('Database error'))
vi.mocked(deleteDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')

View File

@@ -1,16 +1,14 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { TAG_SLOTS } from '@/lib/constants/knowledge'
import { createLogger } from '@/lib/logs/console/logger'
import {
checkDocumentAccess,
checkDocumentWriteAccess,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'
deleteDocument,
markDocumentAsFailedTimeout,
retryDocumentProcessing,
updateDocument,
} from '@/lib/knowledge/documents/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
const logger = createLogger('DocumentByIdAPI')
@@ -113,9 +111,7 @@ export async function PUT(
const updateData: any = {}
// Handle special operations first
if (validatedData.markFailedDueToTimeout) {
// Mark document as failed due to timeout (replaces mark-failed endpoint)
const doc = accessCheck.document
if (doc.processingStatus !== 'processing') {
@@ -132,58 +128,30 @@ export async function PUT(
)
}
const now = new Date()
const processingDuration = now.getTime() - new Date(doc.processingStartedAt).getTime()
const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000
try {
await markDocumentAsFailedTimeout(documentId, doc.processingStartedAt, requestId)
if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) {
return NextResponse.json(
{ error: 'Document has not been processing long enough to be considered dead' },
{ status: 400 }
)
}
return NextResponse.json({
success: true,
data: {
documentId,
status: 'failed',
message: 'Document marked as failed due to timeout',
},
})
} catch (error) {
if (error instanceof Error) {
return NextResponse.json({ error: error.message }, { status: 400 })
}
throw error
}
updateData.processingStatus = 'failed'
updateData.processingError =
'Processing timed out - background process may have been terminated'
updateData.processingCompletedAt = now
logger.info(
`[${requestId}] Marked document ${documentId} as failed due to dead process (processing time: ${Math.round(processingDuration / 1000)}s)`
)
} else if (validatedData.retryProcessing) {
// Retry processing (replaces retry endpoint)
const doc = accessCheck.document
if (doc.processingStatus !== 'failed') {
return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 })
}
// Clear existing embeddings and reset document state
await db.transaction(async (tx) => {
await tx.delete(embedding).where(eq(embedding.documentId, documentId))
await tx
.update(document)
.set({
processingStatus: 'pending',
processingStartedAt: null,
processingCompletedAt: null,
processingError: null,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
})
.where(eq(document.id, documentId))
})
const processingOptions = {
chunkSize: 1024,
minCharactersPerChunk: 24,
recipe: 'default',
lang: 'en',
}
const docData = {
filename: doc.filename,
fileUrl: doc.fileUrl,
@@ -191,80 +159,33 @@ export async function PUT(
mimeType: doc.mimeType,
}
processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch(
(error: unknown) => {
logger.error(`[${requestId}] Background retry processing error:`, error)
}
)
const result = await retryDocumentProcessing(
knowledgeBaseId,
documentId,
docData,
requestId
)
logger.info(`[${requestId}] Document retry initiated: ${documentId}`)
return NextResponse.json({
success: true,
data: {
documentId,
status: 'pending',
message: 'Document retry processing started',
status: result.status,
message: result.message,
},
})
} else {
// Regular field updates
if (validatedData.filename !== undefined) updateData.filename = validatedData.filename
if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount
if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount
if (validatedData.characterCount !== undefined)
updateData.characterCount = validatedData.characterCount
if (validatedData.processingStatus !== undefined)
updateData.processingStatus = validatedData.processingStatus
if (validatedData.processingError !== undefined)
updateData.processingError = validatedData.processingError
const updatedDocument = await updateDocument(documentId, validatedData, requestId)
// Tag field updates
TAG_SLOTS.forEach((slot) => {
if ((validatedData as any)[slot] !== undefined) {
;(updateData as any)[slot] = (validatedData as any)[slot]
}
})
logger.info(
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: updatedDocument,
})
}
await db.transaction(async (tx) => {
// Update the document
await tx.update(document).set(updateData).where(eq(document.id, documentId))
// If any tag fields were updated, also update the embeddings
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)
if (hasTagUpdates) {
const embeddingUpdateData: Record<string, string | null> = {}
TAG_SLOTS.forEach((field) => {
if ((validatedData as any)[field] !== undefined) {
embeddingUpdateData[field] = (validatedData as any)[field] || null
}
})
await tx
.update(embedding)
.set(embeddingUpdateData)
.where(eq(embedding.documentId, documentId))
}
})
// Fetch the updated document
const updatedDocument = await db
.select()
.from(document)
.where(eq(document.id, documentId))
.limit(1)
logger.info(
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: updatedDocument[0],
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid document update data`, {
@@ -313,13 +234,7 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Soft delete by setting deletedAt timestamp
await db
.update(document)
.set({
deletedAt: new Date(),
})
.where(eq(document.id, documentId))
const result = await deleteDocument(documentId, requestId)
logger.info(
`[${requestId}] Document deleted: ${documentId} from knowledge base ${knowledgeBaseId}`
@@ -327,7 +242,7 @@ export async function DELETE(
return NextResponse.json({
success: true,
data: { message: 'Document deleted successfully' },
data: result,
})
} catch (error) {
logger.error(`[${requestId}] Error deleting document`, error)
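
Between the rewritten route and the updated tests, the call shapes of the new document service are pinned down fairly tightly. Below is a sketch of the surface of @/lib/knowledge/documents/service as implied by the call sites above; the return types are assumptions reconstructed from the mocks, not the module itself.

// Signatures inferred from call sites in this diff; field types are assumptions.
interface RetryDocumentData {
  filename: string
  fileUrl: string
  fileSize: number
  mimeType: string
}

declare function updateDocument(
  documentId: string,
  updates: Record<string, unknown>,
  requestId: string
): Promise<Record<string, unknown>> // resolves to the updated document row

declare function deleteDocument(
  documentId: string,
  requestId: string
): Promise<{ success: boolean; message: string }>

declare function markDocumentAsFailedTimeout(
  documentId: string,
  processingStartedAt: Date | string,
  requestId: string
): Promise<{ success: boolean; processingDuration: number }>

declare function retryDocumentProcessing(
  knowledgeBaseId: string,
  documentId: string,
  docData: RetryDocumentData,
  requestId: string
): Promise<{ success: boolean; status: string; message: string }>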

View File

@@ -1,17 +1,17 @@
import { randomUUID } from 'crypto'
import { and, eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
import {
getMaxSlotsForFieldType,
getSlotsForFieldType,
SUPPORTED_FIELD_TYPES,
} from '@/lib/constants/knowledge'
import {
  cleanupUnusedTagDefinitions,
createOrUpdateTagDefinitionsBulk,
deleteAllTagDefinitions,
getDocumentTagDefinitions,
} from '@/lib/knowledge/tags/service'
import type { BulkTagDefinitionsData } from '@/lib/knowledge/tags/types'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
export const dynamic = 'force-dynamic'
@@ -29,106 +29,6 @@ const BulkTagDefinitionsSchema = z.object({
definitions: z.array(TagDefinitionSchema),
})
// Helper function to get the next available slot for a knowledge base and field type
async function getNextAvailableSlot(
knowledgeBaseId: string,
fieldType: string,
existingBySlot?: Map<string, any>
): Promise<string | null> {
// Get available slots for this field type
const availableSlots = getSlotsForFieldType(fieldType)
let usedSlots: Set<string>
if (existingBySlot) {
// Use provided map if available (for performance in batch operations)
// Filter by field type
usedSlots = new Set(
Array.from(existingBySlot.entries())
.filter(([_, def]) => def.fieldType === fieldType)
.map(([slot, _]) => slot)
)
} else {
// Query database for existing tag definitions of the same field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
}
// Find the first available slot for this field type
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
return slot
}
}
return null // No available slots for this field type
}
// Helper function to clean up unused tag definitions
async function cleanupUnusedTagDefinitions(knowledgeBaseId: string, requestId: string) {
try {
logger.info(`[${requestId}] Starting cleanup for KB ${knowledgeBaseId}`)
// Get all tag definitions for this KB
const allDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
logger.info(`[${requestId}] Found ${allDefinitions.length} tag definitions to check`)
if (allDefinitions.length === 0) {
return 0
}
let cleanedCount = 0
// For each tag definition, check if any documents use that tag slot
for (const definition of allDefinitions) {
const slot = definition.tagSlot
// Use raw SQL with proper column name injection
const countResult = await db.execute(sql`
SELECT count(*) as count
FROM document
WHERE knowledge_base_id = ${knowledgeBaseId}
AND ${sql.raw(slot)} IS NOT NULL
AND trim(${sql.raw(slot)}) != ''
`)
const count = Number(countResult[0]?.count) || 0
logger.info(
`[${requestId}] Tag ${definition.displayName} (${slot}): ${count} documents using it`
)
// If count is 0, remove this tag definition
if (count === 0) {
await db
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.id, definition.id))
cleanedCount++
logger.info(
`[${requestId}] Removed unused tag definition: ${definition.displayName} (${definition.tagSlot})`
)
}
}
return cleanedCount
} catch (error) {
logger.warn(`[${requestId}] Failed to cleanup unused tag definitions:`, error)
return 0 // Don't fail the main operation if cleanup fails
}
}
// GET /api/knowledge/[id]/documents/[documentId]/tag-definitions - Get tag definitions for a document
export async function GET(
req: NextRequest,
@@ -145,35 +45,22 @@ export async function GET(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Verify document exists and belongs to the knowledge base
const documentExists = await db
.select({ id: document.id })
.from(document)
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
.limit(1)
if (documentExists.length === 0) {
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
}
const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Get tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
fieldType: knowledgeBaseTagDefinitions.fieldType,
createdAt: knowledgeBaseTagDefinitions.createdAt,
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const tagDefinitions = await getDocumentTagDefinitions(knowledgeBaseId)
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
@@ -203,21 +90,19 @@ export async function POST(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Verify document exists and user has write access
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Verify document exists and belongs to the knowledge base
const documentExists = await db
.select({ id: document.id })
.from(document)
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
.limit(1)
if (documentExists.length === 0) {
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
}
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
let body
@@ -238,197 +123,24 @@ export async function POST(
const validatedData = BulkTagDefinitionsSchema.parse(body)
// Validate slots are valid for their field types
for (const definition of validatedData.definitions) {
const validSlots = getSlotsForFieldType(definition.fieldType)
if (validSlots.length === 0) {
return NextResponse.json(
{ error: `Unsupported field type: ${definition.fieldType}` },
{ status: 400 }
)
}
if (!validSlots.includes(definition.tagSlot)) {
return NextResponse.json(
{
error: `Invalid slot '${definition.tagSlot}' for field type '${definition.fieldType}'. Valid slots: ${validSlots.join(', ')}`,
},
{ status: 400 }
)
}
}
const bulkData: BulkTagDefinitionsData = {
definitions: validatedData.definitions.map((def) => ({
tagSlot: def.tagSlot,
displayName: def.displayName,
fieldType: def.fieldType,
originalDisplayName: def._originalDisplayName,
})),
}
// Validate no duplicate tag slots within the same field type
const slotsByFieldType = new Map<string, Set<string>>()
for (const definition of validatedData.definitions) {
if (!slotsByFieldType.has(definition.fieldType)) {
slotsByFieldType.set(definition.fieldType, new Set())
}
const slotsForType = slotsByFieldType.get(definition.fieldType)!
if (slotsForType.has(definition.tagSlot)) {
return NextResponse.json(
{
error: `Duplicate slot '${definition.tagSlot}' for field type '${definition.fieldType}'`,
},
{ status: 400 }
)
}
slotsForType.add(definition.tagSlot)
}
const now = new Date()
const createdDefinitions: (typeof knowledgeBaseTagDefinitions.$inferSelect)[] = []
// Get existing definitions
const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
// Group by field type for validation
const existingByFieldType = new Map<string, number>()
for (const def of existingDefinitions) {
existingByFieldType.set(def.fieldType, (existingByFieldType.get(def.fieldType) || 0) + 1)
}
// Validate we don't exceed limits per field type
const newByFieldType = new Map<string, number>()
for (const definition of validatedData.definitions) {
// Skip validation for edit operations - they don't create new slots
if (definition._originalDisplayName) {
continue
}
const existingTagNames = new Set(
existingDefinitions
.filter((def) => def.fieldType === definition.fieldType)
.map((def) => def.displayName)
)
if (!existingTagNames.has(definition.displayName)) {
newByFieldType.set(
definition.fieldType,
(newByFieldType.get(definition.fieldType) || 0) + 1
)
}
}
for (const [fieldType, newCount] of newByFieldType.entries()) {
const existingCount = existingByFieldType.get(fieldType) || 0
const maxSlots = getMaxSlotsForFieldType(fieldType)
if (existingCount + newCount > maxSlots) {
return NextResponse.json(
{
error: `Cannot create ${newCount} new '${fieldType}' tags. Knowledge base already has ${existingCount} '${fieldType}' tag definitions. Maximum is ${maxSlots} per field type.`,
},
{ status: 400 }
)
}
}
// Use transaction to ensure consistency
await db.transaction(async (tx) => {
// Create maps for lookups
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
// Process each definition
for (const definition of validatedData.definitions) {
if (definition._originalDisplayName) {
// This is an EDIT operation - find by original name and update
const originalDefinition = existingByName.get(definition._originalDisplayName)
if (originalDefinition) {
logger.info(
`[${requestId}] Editing tag definition: ${definition._originalDisplayName} -> ${definition.displayName} (slot ${originalDefinition.tagSlot})`
)
await tx
.update(knowledgeBaseTagDefinitions)
.set({
displayName: definition.displayName,
fieldType: definition.fieldType,
updatedAt: now,
})
.where(eq(knowledgeBaseTagDefinitions.id, originalDefinition.id))
createdDefinitions.push({
...originalDefinition,
displayName: definition.displayName,
fieldType: definition.fieldType,
updatedAt: now,
})
continue
}
logger.warn(
`[${requestId}] Could not find original definition for: ${definition._originalDisplayName}`
)
}
// Regular create/update logic
const existingByDisplayName = existingByName.get(definition.displayName)
if (existingByDisplayName) {
// Display name exists - UPDATE operation
logger.info(
`[${requestId}] Updating existing tag definition: ${definition.displayName} (slot ${existingByDisplayName.tagSlot})`
)
await tx
.update(knowledgeBaseTagDefinitions)
.set({
fieldType: definition.fieldType,
updatedAt: now,
})
.where(eq(knowledgeBaseTagDefinitions.id, existingByDisplayName.id))
createdDefinitions.push({
...existingByDisplayName,
fieldType: definition.fieldType,
updatedAt: now,
})
} else {
// Display name doesn't exist - CREATE operation
const targetSlot = await getNextAvailableSlot(
knowledgeBaseId,
definition.fieldType,
existingBySlot
)
if (!targetSlot) {
logger.error(
`[${requestId}] No available slots for new tag definition: ${definition.displayName}`
)
continue
}
logger.info(
`[${requestId}] Creating new tag definition: ${definition.displayName} -> ${targetSlot}`
)
const newDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot: targetSlot as any,
displayName: definition.displayName,
fieldType: definition.fieldType,
createdAt: now,
updatedAt: now,
}
await tx.insert(knowledgeBaseTagDefinitions).values(newDefinition)
existingBySlot.set(targetSlot as any, newDefinition)
createdDefinitions.push(newDefinition as any)
}
}
})
logger.info(`[${requestId}] Created/updated ${createdDefinitions.length} tag definitions`)
const result = await createOrUpdateTagDefinitionsBulk(knowledgeBaseId, bulkData, requestId)
return NextResponse.json({
success: true,
data: createdDefinitions,
data: {
created: result.created,
updated: result.updated,
errors: result.errors,
},
})
} catch (error) {
if (error instanceof z.ZodError) {
@@ -459,10 +171,19 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Verify document exists and user has write access
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (action === 'cleanup') {
@@ -478,13 +199,12 @@ export async function DELETE(
// Delete all tag definitions (original behavior)
logger.info(`[${requestId}] Deleting all tag definitions for KB ${knowledgeBaseId}`)
const result = await db
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const deletedCount = await deleteAllTagDefinitions(knowledgeBaseId, requestId)
return NextResponse.json({
success: true,
message: 'Tag definitions deleted successfully',
data: { deleted: deletedCount },
})
} catch (error) {
logger.error(`[${requestId}] Error with tag definitions operation`, error)
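
The tag-definitions route follows the same pattern: slot allocation, bulk upsert, and cleanup move into @/lib/knowledge/tags/service. A sketch of the surface implied by the call sites above; payload field types are assumptions, and created/updated could just as well be row arrays rather than counts.

// Inferred from the route above; the service file itself is not shown in this diff.
interface TagDefinitionInput {
  tagSlot: string
  displayName: string
  fieldType: string
  originalDisplayName?: string // set when renaming an existing definition
}

interface BulkTagDefinitionsData {
  definitions: TagDefinitionInput[]
}

declare function getDocumentTagDefinitions(
  knowledgeBaseId: string
): Promise<Array<{ id: string; tagSlot: string; displayName: string; fieldType: string }>>

declare function createOrUpdateTagDefinitionsBulk(
  knowledgeBaseId: string,
  data: BulkTagDefinitionsData,
  requestId: string
): Promise<{ created: number; updated: number; errors: string[] }> // shapes assumed

declare function deleteAllTagDefinitions(
  knowledgeBaseId: string,
  requestId: string
): Promise<number> // count of deleted definitions

declare function cleanupUnusedTagDefinitions(
  knowledgeBaseId: string,
  requestId: string
): Promise<number> // count of removed unused definitions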

View File

@@ -24,6 +24,19 @@ vi.mock('@/app/api/knowledge/utils', () => ({
processDocumentAsync: vi.fn(),
}))
vi.mock('@/lib/knowledge/documents/service', () => ({
getDocuments: vi.fn(),
createSingleDocument: vi.fn(),
createDocumentRecords: vi.fn(),
processDocumentsWithQueue: vi.fn(),
getProcessingConfig: vi.fn(),
bulkDocumentOperation: vi.fn(),
updateDocument: vi.fn(),
deleteDocument: vi.fn(),
markDocumentAsFailedTimeout: vi.fn(),
retryDocumentProcessing: vi.fn(),
}))
mockDrizzleOrm()
mockConsoleLogger()
@@ -72,7 +85,6 @@ describe('Knowledge Base Documents API Route', () => {
}
}
})
// Clear all mocks - they will be set up in individual tests
}
beforeEach(async () => {
@@ -96,6 +108,7 @@ describe('Knowledge Base Documents API Route', () => {
it('should retrieve documents successfully for authenticated user', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -103,11 +116,15 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -118,12 +135,22 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.success).toBe(true)
expect(data.data.documents).toHaveLength(1)
expect(data.data.documents[0].id).toBe('doc-123')
expect(mockDbChain.select).toHaveBeenCalled()
expect(vi.mocked(checkKnowledgeBaseAccess)).toHaveBeenCalledWith('kb-123', 'user-123')
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: false,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should filter disabled documents by default', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -131,22 +158,36 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
const response = await GET(req, { params: mockParams })
expect(response.status).toBe(200)
expect(mockDbChain.where).toHaveBeenCalled()
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: false,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should include disabled documents when requested', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -154,11 +195,15 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const url = 'http://localhost:3000/api/knowledge/kb-123/documents?includeDisabled=true'
const req = new Request(url, { method: 'GET' }) as any
@@ -167,6 +212,16 @@ describe('Knowledge Base Documents API Route', () => {
const response = await GET(req, { params: mockParams })
expect(response.status).toBe(200)
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: true,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -216,13 +271,14 @@ describe('Knowledge Base Documents API Route', () => {
it('should handle database errors', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.orderBy.mockRejectedValue(new Error('Database error'))
vi.mocked(getDocuments).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -245,13 +301,35 @@ describe('Knowledge Base Documents API Route', () => {
it('should create single document successfully', async () => {
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.values.mockResolvedValue(undefined)
const createdDocument = {
id: 'doc-123',
knowledgeBaseId: 'kb-123',
filename: validDocumentData.filename,
fileUrl: validDocumentData.fileUrl,
fileSize: validDocumentData.fileSize,
mimeType: validDocumentData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
enabled: true,
uploadedAt: new Date(),
tag1: null,
tag2: null,
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
}
vi.mocked(createSingleDocument).mockResolvedValue(createdDocument)
const req = createMockRequest('POST', validDocumentData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -262,7 +340,11 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.success).toBe(true)
expect(data.data.filename).toBe(validDocumentData.filename)
expect(data.data.fileUrl).toBe(validDocumentData.fileUrl)
expect(mockDbChain.insert).toHaveBeenCalled()
expect(vi.mocked(createSingleDocument)).toHaveBeenCalledWith(
validDocumentData,
'kb-123',
expect.any(String)
)
})
it('should validate single document data', async () => {
@@ -320,9 +402,9 @@ describe('Knowledge Base Documents API Route', () => {
}
it('should create bulk documents successfully', async () => {
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
@@ -330,17 +412,31 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to return the created documents
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
insert: vi.fn().mockReturnValue({
values: vi.fn().mockResolvedValue(undefined),
}),
}
return await callback(mockTx)
})
const createdDocuments = [
{
documentId: 'doc-1',
filename: 'doc1.pdf',
fileUrl: 'https://example.com/doc1.pdf',
fileSize: 1024,
mimeType: 'application/pdf',
},
{
documentId: 'doc-2',
filename: 'doc2.pdf',
fileUrl: 'https://example.com/doc2.pdf',
fileSize: 2048,
mimeType: 'application/pdf',
},
]
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
vi.mocked(getProcessingConfig).mockReturnValue({
maxConcurrentDocuments: 8,
batchSize: 20,
delayBetweenBatches: 100,
delayBetweenDocuments: 0,
})
const req = createMockRequest('POST', validBulkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -352,7 +448,12 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.data.total).toBe(2)
expect(data.data.documentsCreated).toHaveLength(2)
expect(data.data.processingMethod).toBe('background')
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(vi.mocked(createDocumentRecords)).toHaveBeenCalledWith(
validBulkData.documents,
'kb-123',
expect.any(String)
)
expect(vi.mocked(processDocumentsWithQueue)).toHaveBeenCalled()
})
it('should validate bulk document data', async () => {
@@ -394,9 +495,9 @@ describe('Knowledge Base Documents API Route', () => {
})
it('should handle processing errors gracefully', async () => {
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
@@ -404,26 +505,30 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to succeed but processing to fail
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
insert: vi.fn().mockReturnValue({
values: vi.fn().mockResolvedValue(undefined),
}),
}
return await callback(mockTx)
})
const createdDocuments = [
{
documentId: 'doc-1',
filename: 'doc1.pdf',
fileUrl: 'https://example.com/doc1.pdf',
fileSize: 1024,
mimeType: 'application/pdf',
},
]
// Don't reject the promise - the processing is async and catches errors internally
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
vi.mocked(getProcessingConfig).mockReturnValue({
maxConcurrentDocuments: 8,
batchSize: 20,
delayBetweenBatches: 100,
delayBetweenDocuments: 0,
})
const req = createMockRequest('POST', validBulkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
// The endpoint should still return success since documents are created
// and processing happens asynchronously
expect(response.status).toBe(200)
expect(data.success).toBe(true)
})
@@ -485,13 +590,14 @@ describe('Knowledge Base Documents API Route', () => {
it('should handle database errors during creation', async () => {
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.values.mockRejectedValue(new Error('Database error'))
vi.mocked(createSingleDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('POST', validDocumentData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
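
These tests also fix the argument and return shapes for the list/create half of the service. A sketch of what the mocks above imply; document rows are left as unknown since the tests only spread a mock object, and the trailing arguments of processDocumentsWithQueue are assumed from the route.

// Shapes pinned down by the mocks above; anything beyond these fields is an assumption.
interface CreatedDocumentRecord {
  documentId: string
  filename: string
  fileUrl: string
  fileSize: number
  mimeType: string
}

declare function getDocuments(
  knowledgeBaseId: string,
  options: {
    includeDisabled: boolean
    search?: string
    limit: number
    offset: number
    sortBy?: string
    sortOrder?: 'asc' | 'desc'
  },
  requestId: string
): Promise<{
  documents: unknown[]
  pagination: { total: number; limit: number; offset: number; hasMore: boolean }
}>

declare function createDocumentRecords(
  documents: unknown[],
  knowledgeBaseId: string,
  requestId: string
): Promise<CreatedDocumentRecord[]>

declare function processDocumentsWithQueue(
  createdDocuments: CreatedDocumentRecord[],
  knowledgeBaseId: string,
  processingOptions: unknown, // chunking options; exact type not shown in this diff
  requestId: string
): Promise<void>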

View File

@@ -1,279 +1,22 @@
import { randomUUID } from 'crypto'
import { and, desc, eq, inArray, isNull, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { getSlotsForFieldType } from '@/lib/constants/knowledge'
import {
bulkDocumentOperation,
createDocumentRecords,
createSingleDocument,
getDocuments,
getProcessingConfig,
processDocumentsWithQueue,
} from '@/lib/knowledge/documents/service'
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserId } from '@/app/api/auth/oauth/utils'
import {
checkKnowledgeBaseAccess,
checkKnowledgeBaseWriteAccess,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
const logger = createLogger('DocumentsAPI')
const PROCESSING_CONFIG = {
maxConcurrentDocuments: 3,
batchSize: 5,
delayBetweenBatches: 1000,
delayBetweenDocuments: 500,
}
// Helper function to get the next available slot for a knowledge base and field type
async function getNextAvailableSlot(
knowledgeBaseId: string,
fieldType: string,
existingBySlot?: Map<string, any>
): Promise<string | null> {
let usedSlots: Set<string>
if (existingBySlot) {
// Use provided map if available (for performance in batch operations)
// Filter by field type
usedSlots = new Set(
Array.from(existingBySlot.entries())
.filter(([_, def]) => def.fieldType === fieldType)
.map(([slot, _]) => slot)
)
} else {
// Query database for existing tag definitions of the same field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
}
// Find the first available slot for this field type
const availableSlots = getSlotsForFieldType(fieldType)
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
return slot
}
}
return null // No available slots for this field type
}
// Helper function to process structured document tags
async function processDocumentTags(
knowledgeBaseId: string,
tagData: Array<{ tagName: string; fieldType: string; value: string }>,
requestId: string
): Promise<Record<string, string | null>> {
const result: Record<string, string | null> = {}
// Initialize all text tag slots to null (only text type is supported currently)
const textSlots = getSlotsForFieldType('text')
textSlots.forEach((slot) => {
result[slot] = null
})
if (!Array.isArray(tagData) || tagData.length === 0) {
return result
}
try {
// Get existing tag definitions
const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
// Process each tag
for (const tag of tagData) {
if (!tag.tagName?.trim() || !tag.value?.trim()) continue
const tagName = tag.tagName.trim()
const fieldType = tag.fieldType
const value = tag.value.trim()
let targetSlot: string | null = null
// Check if tag definition already exists
const existingDef = existingByName.get(tagName)
if (existingDef) {
targetSlot = existingDef.tagSlot
} else {
// Find next available slot using the helper function
targetSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
// Create new tag definition if we have a slot
if (targetSlot) {
const newDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot: targetSlot as any,
displayName: tagName,
fieldType,
createdAt: new Date(),
updatedAt: new Date(),
}
await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
existingBySlot.set(targetSlot as any, newDefinition)
logger.info(`[${requestId}] Created tag definition: ${tagName} -> ${targetSlot}`)
}
}
// Assign value to the slot
if (targetSlot) {
result[targetSlot] = value
}
}
return result
} catch (error) {
logger.error(`[${requestId}] Error processing document tags:`, error)
return result
}
}
async function processDocumentsWithConcurrencyControl(
createdDocuments: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const totalDocuments = createdDocuments.length
const batches = []
for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) {
batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize))
}
logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`)
for (const [batchIndex, batch] of batches.entries()) {
logger.info(
`[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents`
)
await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId)
if (batchIndex < batches.length - 1) {
await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches))
}
}
logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`)
}
async function processBatchWithConcurrency(
batch: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0)
const processingPromises = batch.map(async (doc, index) => {
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments)
)
}
await new Promise<void>((resolve) => {
const checkSlot = () => {
const availableIndex = semaphore.findIndex((slot) => slot === 0)
if (availableIndex !== -1) {
semaphore[availableIndex] = 1
resolve()
} else {
setTimeout(checkSlot, 100)
}
}
checkSlot()
})
try {
logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`)
await processDocumentAsync(
knowledgeBaseId,
doc.documentId,
{
filename: doc.filename,
fileUrl: doc.fileUrl,
fileSize: doc.fileSize,
mimeType: doc.mimeType,
},
processingOptions
)
logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`)
} catch (error: unknown) {
logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, {
documentId: doc.documentId,
filename: doc.filename,
error: error instanceof Error ? error.message : 'Unknown error',
})
try {
await db
.update(document)
.set({
processingStatus: 'failed',
processingError:
error instanceof Error ? error.message : 'Failed to initiate processing',
processingCompletedAt: new Date(),
})
.where(eq(document.id, doc.documentId))
} catch (dbError: unknown) {
logger.error(
`[${requestId}] Failed to update document status for failed document: ${doc.documentId}`,
dbError
)
}
} finally {
const slotIndex = semaphore.findIndex((slot) => slot === 1)
if (slotIndex !== -1) {
semaphore[slotIndex] = 0
}
}
})
await Promise.allSettled(processingPromises)
}
const CreateDocumentSchema = z.object({
filename: z.string().min(1, 'Filename is required'),
fileUrl: z.string().url('File URL must be valid'),
@@ -337,83 +80,50 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
const url = new URL(req.url)
const includeDisabled = url.searchParams.get('includeDisabled') === 'true'
const search = url.searchParams.get('search') || undefined
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
const offset = Number.parseInt(url.searchParams.get('offset') || '0')
const sortByParam = url.searchParams.get('sortBy')
const sortOrderParam = url.searchParams.get('sortOrder')
// Build where conditions
const whereConditions = [
eq(document.knowledgeBaseId, knowledgeBaseId),
isNull(document.deletedAt),
]
// Validate sort parameters
const validSortFields: DocumentSortField[] = [
'filename',
'fileSize',
'tokenCount',
'chunkCount',
'uploadedAt',
'processingStatus',
]
const validSortOrders: SortOrder[] = ['asc', 'desc']
// Filter out disabled documents unless specifically requested
if (!includeDisabled) {
whereConditions.push(eq(document.enabled, true))
}
const sortBy =
sortByParam && validSortFields.includes(sortByParam as DocumentSortField)
? (sortByParam as DocumentSortField)
: undefined
const sortOrder =
sortOrderParam && validSortOrders.includes(sortOrderParam as SortOrder)
? (sortOrderParam as SortOrder)
: undefined
// Add search condition if provided
if (search) {
whereConditions.push(
// Search in filename
sql`LOWER(${document.filename}) LIKE LOWER(${`%${search}%`})`
)
}
// Get total count for pagination
const totalResult = await db
.select({ count: sql<number>`COUNT(*)` })
.from(document)
.where(and(...whereConditions))
const total = totalResult[0]?.count || 0
const hasMore = offset + limit < total
const documents = await db
.select({
id: document.id,
filename: document.filename,
fileUrl: document.fileUrl,
fileSize: document.fileSize,
mimeType: document.mimeType,
chunkCount: document.chunkCount,
tokenCount: document.tokenCount,
characterCount: document.characterCount,
processingStatus: document.processingStatus,
processingStartedAt: document.processingStartedAt,
processingCompletedAt: document.processingCompletedAt,
processingError: document.processingError,
enabled: document.enabled,
uploadedAt: document.uploadedAt,
// Include tags in response
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
})
.from(document)
.where(and(...whereConditions))
.orderBy(desc(document.uploadedAt))
.limit(limit)
.offset(offset)
logger.info(
`[${requestId}] Retrieved ${documents.length} documents (${offset}-${offset + documents.length} of ${total}) for knowledge base ${knowledgeBaseId}`
)
const result = await getDocuments(
knowledgeBaseId,
{
includeDisabled,
search,
limit,
offset,
...(sortBy && { sortBy }),
...(sortOrder && { sortOrder }),
},
requestId
)
return NextResponse.json({
success: true,
data: {
documents: result.documents,
pagination: result.pagination,
},
})
} catch (error) {
@@ -462,80 +172,21 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if this is a bulk operation
if (body.bulk === true) {
// Handle bulk processing (replaces process-documents endpoint)
try {
const validatedData = BulkCreateDocumentsSchema.parse(body)
const createdDocuments = await db.transaction(async (tx) => {
const documentPromises = validatedData.documents.map(async (docData) => {
const documentId = randomUUID()
const now = new Date()
// Process documentTagsData if provided (for knowledge base block)
let processedTags: Record<string, string | null> = {
tag1: null,
tag2: null,
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
}
if (docData.documentTagsData) {
try {
const tagData = JSON.parse(docData.documentTagsData)
if (Array.isArray(tagData)) {
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
}
} catch (error) {
logger.warn(
`[${requestId}] Failed to parse documentTagsData for bulk document:`,
error
)
}
}
const newDocument = {
id: documentId,
knowledgeBaseId,
filename: docData.filename,
fileUrl: docData.fileUrl,
fileSize: docData.fileSize,
mimeType: docData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
processingStatus: 'pending' as const,
enabled: true,
uploadedAt: now,
// Use processed tags if available, otherwise fall back to individual tag fields
tag1: processedTags.tag1 || docData.tag1 || null,
tag2: processedTags.tag2 || docData.tag2 || null,
tag3: processedTags.tag3 || docData.tag3 || null,
tag4: processedTags.tag4 || docData.tag4 || null,
tag5: processedTags.tag5 || docData.tag5 || null,
tag6: processedTags.tag6 || docData.tag6 || null,
tag7: processedTags.tag7 || docData.tag7 || null,
}
await tx.insert(document).values(newDocument)
logger.info(
`[${requestId}] Document record created: ${documentId} for file: ${docData.filename}`
)
return { documentId, ...docData }
})
return await Promise.all(documentPromises)
})
const createdDocuments = await createDocumentRecords(
validatedData.documents,
knowledgeBaseId,
requestId
)
logger.info(
`[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents`
)
processDocumentsWithQueue(
createdDocuments,
knowledgeBaseId,
validatedData.processingOptions,
@@ -555,9 +206,9 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
})),
processingMethod: 'background',
processingConfig: {
maxConcurrentDocuments: getProcessingConfig().maxConcurrentDocuments,
batchSize: getProcessingConfig().batchSize,
totalBatches: Math.ceil(createdDocuments.length / getProcessingConfig().batchSize),
},
},
})
@@ -578,52 +229,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
try {
const validatedData = CreateDocumentSchema.parse(body)
const documentId = randomUUID()
const now = new Date()
// Process structured tag data if provided
let processedTags: Record<string, string | null> = {
tag1: validatedData.tag1 || null,
tag2: validatedData.tag2 || null,
tag3: validatedData.tag3 || null,
tag4: validatedData.tag4 || null,
tag5: validatedData.tag5 || null,
tag6: validatedData.tag6 || null,
tag7: validatedData.tag7 || null,
}
if (validatedData.documentTagsData) {
try {
const tagData = JSON.parse(validatedData.documentTagsData)
if (Array.isArray(tagData)) {
// Process structured tag data and create tag definitions
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
}
} catch (error) {
logger.warn(`[${requestId}] Failed to parse documentTagsData:`, error)
}
}
const newDocument = {
id: documentId,
knowledgeBaseId,
filename: validatedData.filename,
fileUrl: validatedData.fileUrl,
fileSize: validatedData.fileSize,
mimeType: validatedData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
enabled: true,
uploadedAt: now,
...processedTags,
}
await db.insert(document).values(newDocument)
logger.info(
`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`
)
const newDocument = await createSingleDocument(validatedData, knowledgeBaseId, requestId)
return NextResponse.json({
success: true,
@@ -649,7 +255,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
}
export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
const { id: knowledgeBaseId } = await params
try {
@@ -678,89 +284,28 @@ export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id
const validatedData = BulkUpdateDocumentsSchema.parse(body)
const { operation, documentIds } = validatedData
logger.info(
`[${requestId}] Starting bulk ${operation} operation on ${documentIds.length} documents in knowledge base ${knowledgeBaseId}`
)
// Verify all documents belong to this knowledge base and user has access
const documentsToUpdate = await db
.select({
id: document.id,
enabled: document.enabled,
})
.from(document)
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
if (documentsToUpdate.length === 0) {
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
}
if (documentsToUpdate.length !== documentIds.length) {
logger.warn(
`[${requestId}] Some documents not found or don't belong to knowledge base. Requested: ${documentIds.length}, Found: ${documentsToUpdate.length}`
)
}
// Perform the bulk operation
let updateResult: Array<{ id: string; enabled?: boolean; deletedAt?: Date | null }>
let successCount: number
if (operation === 'delete') {
// Handle bulk soft delete
updateResult = await db
.update(document)
.set({
deletedAt: new Date(),
})
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
.returning({ id: document.id, deletedAt: document.deletedAt })
successCount = updateResult.length
} else {
// Handle bulk enable/disable
const enabled = operation === 'enable'
updateResult = await db
.update(document)
.set({
enabled,
})
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
.returning({ id: document.id, enabled: document.enabled })
successCount = updateResult.length
}
logger.info(
`[${requestId}] Bulk ${operation} operation completed: ${successCount} documents updated in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: {
operation,
successCount,
updatedDocuments: updateResult,
},
})
try {
const result = await bulkDocumentOperation(
knowledgeBaseId,
operation,
documentIds,
requestId
)
return NextResponse.json({
success: true,
data: {
operation,
successCount: result.successCount,
updatedDocuments: result.updatedDocuments,
},
})
} catch (error) {
if (error instanceof Error && error.message === 'No valid documents found to update') {
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
}
throw error
}
} catch (validationError) {
if (validationError instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid bulk operation data`, {

View File
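// A minimal sketch of the bulkDocumentOperation helper that the PATCH route
// above now delegates to. The signature is inferred from the call site, and
// the behavior mirrors the inlined queries it replaced; the module location
// and exact return type are assumptions.
import { and, eq, inArray, isNull } from 'drizzle-orm'
import { db } from '@/db'
import { document } from '@/db/schema'

export async function bulkDocumentOperationSketch(
  knowledgeBaseId: string,
  operation: 'enable' | 'disable' | 'delete',
  documentIds: string[],
  requestId: string
) {
  const scope = and(
    eq(document.knowledgeBaseId, knowledgeBaseId),
    inArray(document.id, documentIds),
    isNull(document.deletedAt)
  )
  const found = await db.select({ id: document.id }).from(document).where(scope)
  if (found.length === 0) {
    throw new Error('No valid documents found to update') // surfaced as a 404 by the route
  }
  const updatedDocuments =
    operation === 'delete'
      ? await db
          .update(document)
          .set({ deletedAt: new Date() }) // soft delete, as before
          .where(scope)
          .returning({ id: document.id, deletedAt: document.deletedAt })
      : await db
          .update(document)
          .set({ enabled: operation === 'enable' })
          .where(scope)
          .returning({ id: document.id, enabled: document.enabled })
  return { successCount: updatedDocuments.length, updatedDocuments }
}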

@@ -1,12 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getMaxSlotsForFieldType, getSlotsForFieldType } from '@/lib/constants/knowledge'
import { getNextAvailableSlot, getTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
const logger = createLogger('NextAvailableSlotAPI')
@@ -31,51 +28,36 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has read access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get available slots for this field type
const availableSlots = getSlotsForFieldType(fieldType)
const maxSlots = getMaxSlotsForFieldType(fieldType)
// Get existing definitions once and reuse
const existingDefinitions = await getTagDefinitions(knowledgeBaseId)
const usedSlots = existingDefinitions
.filter((def) => def.fieldType === fieldType)
.map((def) => def.tagSlot)
// Get existing tag definitions to find used slots for this field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
const usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot as string))
// Find the first available slot for this field type
let nextAvailableSlot: string | null = null
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
nextAvailableSlot = slot
break
}
}
// Create a map for efficient lookup and pass to avoid redundant query
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot as string, def]))
const nextAvailableSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
logger.info(
`[${requestId}] Next available slot for fieldType ${fieldType}: ${nextAvailableSlot}`
)
const result = {
nextAvailableSlot,
fieldType,
usedSlots,
totalSlots: 7,
availableSlots: nextAvailableSlot ? 7 - usedSlots.length : 0,
}
return NextResponse.json({
success: true,
data: {
nextAvailableSlot,
fieldType,
usedSlots: Array.from(usedSlots),
totalSlots: maxSlots,
availableSlots: maxSlots - usedSlots.size,
},
data: result,
})
} catch (error) {
logger.error(`[${requestId}] Error getting next available slot`, error)

View File
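// A sketch of getNextAvailableSlot consistent with the call above: it walks
// the configured slots for the field type and returns the first one without a
// definition. The optional map parameter lets callers that already loaded the
// definitions (as the route does) avoid a second query; treating its keys as
// already scoped to the field type is an assumption.
import { getSlotsForFieldType } from '@/lib/constants/knowledge'
import { getTagDefinitions } from '@/lib/knowledge/tags/service'

export async function getNextAvailableSlotSketch(
  knowledgeBaseId: string,
  fieldType: string,
  existingBySlot?: Map<string, unknown>
): Promise<string | null> {
  const used =
    existingBySlot ??
    new Map(
      (await getTagDefinitions(knowledgeBaseId))
        .filter((def) => def.fieldType === fieldType)
        .map((def) => [def.tagSlot as string, def] as const)
    )
  for (const slot of getSlotsForFieldType(fieldType)) {
    if (!used.has(slot)) return slot
  }
  return null // every slot for this field type is taken
}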

@@ -16,9 +16,26 @@ mockKnowledgeSchemas()
mockDrizzleOrm()
mockConsoleLogger()
vi.mock('@/lib/knowledge/service', () => ({
getKnowledgeBaseById: vi.fn(),
updateKnowledgeBase: vi.fn(),
deleteKnowledgeBase: vi.fn(),
}))
vi.mock('@/app/api/knowledge/utils', () => ({
checkKnowledgeBaseAccess: vi.fn(),
checkKnowledgeBaseWriteAccess: vi.fn(),
}))
describe('Knowledge Base By ID API Route', () => {
const mockAuth$ = mockAuth()
let mockGetKnowledgeBaseById: any
let mockUpdateKnowledgeBase: any
let mockDeleteKnowledgeBase: any
let mockCheckKnowledgeBaseAccess: any
let mockCheckKnowledgeBaseWriteAccess: any
const mockDbChain = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
@@ -62,6 +79,15 @@ describe('Knowledge Base By ID API Route', () => {
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
})
const knowledgeService = await import('@/lib/knowledge/service')
const knowledgeUtils = await import('@/app/api/knowledge/utils')
mockGetKnowledgeBaseById = knowledgeService.getKnowledgeBaseById as any
mockUpdateKnowledgeBase = knowledgeService.updateKnowledgeBase as any
mockDeleteKnowledgeBase = knowledgeService.deleteKnowledgeBase as any
mockCheckKnowledgeBaseAccess = knowledgeUtils.checkKnowledgeBaseAccess as any
mockCheckKnowledgeBaseWriteAccess = knowledgeUtils.checkKnowledgeBaseWriteAccess as any
})
afterEach(() => {
@@ -74,9 +100,12 @@ describe('Knowledge Base By ID API Route', () => {
it('should retrieve knowledge base successfully for authenticated user', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.limit.mockResolvedValueOnce([mockKnowledgeBase])
mockGetKnowledgeBaseById.mockResolvedValueOnce(mockKnowledgeBase)
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -87,7 +116,8 @@ describe('Knowledge Base By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.id).toBe('kb-123')
expect(data.data.name).toBe('Test Knowledge Base')
expect(mockDbChain.select).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockGetKnowledgeBaseById).toHaveBeenCalledWith('kb-123')
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -105,7 +135,10 @@ describe('Knowledge Base By ID API Route', () => {
it('should return not found for non-existent knowledge base', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -119,7 +152,10 @@ describe('Knowledge Base By ID API Route', () => {
it('should return unauthorized for knowledge base owned by different user', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: false,
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -130,9 +166,29 @@ describe('Knowledge Base By ID API Route', () => {
expect(data.error).toBe('Unauthorized')
})
it('should return not found when service returns null', async () => {
mockAuth$.mockAuthenticatedUser()
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockGetKnowledgeBaseById.mockResolvedValueOnce(null)
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
const response = await GET(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Knowledge base not found')
})
it('should handle database errors', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
mockCheckKnowledgeBaseAccess.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -156,13 +212,13 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockResolvedValueOnce(undefined)
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ ...mockKnowledgeBase, ...validUpdateData }])
const updatedKnowledgeBase = { ...mockKnowledgeBase, ...validUpdateData }
mockUpdateKnowledgeBase.mockResolvedValueOnce(updatedKnowledgeBase)
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -172,7 +228,16 @@ describe('Knowledge Base By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.name).toBe('Updated Knowledge Base')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockUpdateKnowledgeBase).toHaveBeenCalledWith(
'kb-123',
{
name: validUpdateData.name,
description: validUpdateData.description,
chunkingConfig: undefined,
},
expect.any(String)
)
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -192,8 +257,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -209,8 +276,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
const invalidData = {
name: '',
@@ -229,9 +298,13 @@ describe('Knowledge Base By ID API Route', () => {
it('should handle database errors during update', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
// Mock successful write access check
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
mockUpdateKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -251,10 +324,12 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockResolvedValueOnce(undefined)
mockDeleteKnowledgeBase.mockResolvedValueOnce(undefined)
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -264,7 +339,8 @@ describe('Knowledge Base By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.message).toBe('Knowledge base deleted successfully')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockDeleteKnowledgeBase).toHaveBeenCalledWith('kb-123', expect.any(String))
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -284,8 +360,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -301,8 +379,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: false,
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -316,9 +396,12 @@ describe('Knowledge Base By ID API Route', () => {
it('should handle database errors during delete', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
mockDeleteKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')

View File

@@ -1,11 +1,13 @@
import { and, eq, isNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import {
deleteKnowledgeBase,
getKnowledgeBaseById,
updateKnowledgeBase,
} from '@/lib/knowledge/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBase } from '@/db/schema'
const logger = createLogger('KnowledgeBaseByIdAPI')
@@ -48,13 +50,9 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const knowledgeBases = await db
.select()
.from(knowledgeBase)
.where(and(eq(knowledgeBase.id, id), isNull(knowledgeBase.deletedAt)))
.limit(1)
const knowledgeBaseData = await getKnowledgeBaseById(id)
if (knowledgeBases.length === 0) {
if (!knowledgeBaseData) {
return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 })
}
@@ -62,7 +60,7 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({
success: true,
data: knowledgeBases[0],
data: knowledgeBaseData,
})
} catch (error) {
logger.error(`[${requestId}] Error fetching knowledge base`, error)
@@ -99,42 +97,21 @@ export async function PUT(req: NextRequest, { params }: { params: Promise<{ id:
try {
const validatedData = UpdateKnowledgeBaseSchema.parse(body)
const updateData: any = {
updatedAt: new Date(),
}
if (validatedData.name !== undefined) updateData.name = validatedData.name
if (validatedData.description !== undefined)
updateData.description = validatedData.description
if (validatedData.workspaceId !== undefined)
updateData.workspaceId = validatedData.workspaceId
// Handle embedding model and dimension together to ensure consistency
if (
validatedData.embeddingModel !== undefined ||
validatedData.embeddingDimension !== undefined
) {
updateData.embeddingModel = 'text-embedding-3-small'
updateData.embeddingDimension = 1536
}
if (validatedData.chunkingConfig !== undefined)
updateData.chunkingConfig = validatedData.chunkingConfig
await db.update(knowledgeBase).set(updateData).where(eq(knowledgeBase.id, id))
// Fetch the updated knowledge base
const updatedKnowledgeBase = await db
.select()
.from(knowledgeBase)
.where(eq(knowledgeBase.id, id))
.limit(1)
const updatedKnowledgeBase = await updateKnowledgeBase(
id,
{
name: validatedData.name,
description: validatedData.description,
chunkingConfig: validatedData.chunkingConfig,
},
requestId
)
logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${session.user.id}`)
return NextResponse.json({
success: true,
data: updatedKnowledgeBase[0],
data: updatedKnowledgeBase,
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
@@ -178,14 +155,7 @@ export async function DELETE(_req: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Soft delete by setting deletedAt timestamp
await db
.update(knowledgeBase)
.set({
deletedAt: new Date(),
updatedAt: new Date(),
})
.where(eq(knowledgeBase.id, id))
await deleteKnowledgeBase(id, requestId)
logger.info(`[${requestId}] Knowledge base deleted: ${id} for user ${session.user.id}`)

View File
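// Sketches of two of the service functions the route above now calls, mirroring
// the inline Drizzle queries they replaced. The module path and the requestId
// parameter (presumably threaded through for logging) are assumptions.
import { and, eq, isNull } from 'drizzle-orm'
import { db } from '@/db'
import { knowledgeBase } from '@/db/schema'

export async function getKnowledgeBaseByIdSketch(id: string) {
  const rows = await db
    .select()
    .from(knowledgeBase)
    .where(and(eq(knowledgeBase.id, id), isNull(knowledgeBase.deletedAt)))
    .limit(1)
  return rows[0] ?? null
}

export async function deleteKnowledgeBaseSketch(id: string, _requestId: string): Promise<void> {
  // Soft delete by setting deletedAt, exactly as the removed inline code did
  await db
    .update(knowledgeBase)
    .set({ deletedAt: new Date(), updatedAt: new Date() })
    .where(eq(knowledgeBase.id, id))
}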

@@ -1,11 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq, isNotNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { deleteTagDefinition } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -29,87 +27,16 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get the tag definition to find which slot it uses
const tagDefinition = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.id, tagId),
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)
)
)
.limit(1)
if (tagDefinition.length === 0) {
return NextResponse.json({ error: 'Tag definition not found' }, { status: 404 })
}
const tagDef = tagDefinition[0]
// Delete the tag definition and clear all document tags in a transaction
await db.transaction(async (tx) => {
logger.info(`[${requestId}] Starting transaction to delete ${tagDef.tagSlot}`)
try {
// Clear the tag from documents that actually have this tag set
logger.info(`[${requestId}] Clearing tag from documents...`)
await tx
.update(document)
.set({ [tagDef.tagSlot]: null })
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
isNotNull(document[tagDef.tagSlot as keyof typeof document.$inferSelect])
)
)
logger.info(`[${requestId}] Documents updated successfully`)
// Clear the tag from embeddings that actually have this tag set
logger.info(`[${requestId}] Clearing tag from embeddings...`)
await tx
.update(embedding)
.set({ [tagDef.tagSlot]: null })
.where(
and(
eq(embedding.knowledgeBaseId, knowledgeBaseId),
isNotNull(embedding[tagDef.tagSlot as keyof typeof embedding.$inferSelect])
)
)
logger.info(`[${requestId}] Embeddings updated successfully`)
// Delete the tag definition
logger.info(`[${requestId}] Deleting tag definition...`)
await tx
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.id, tagId))
logger.info(`[${requestId}] Tag definition deleted successfully`)
} catch (error) {
logger.error(`[${requestId}] Error in transaction:`, error)
throw error
}
})
logger.info(
`[${requestId}] Successfully deleted tag definition ${tagDef.displayName} (${tagDef.tagSlot})`
)
const deletedTag = await deleteTagDefinition(tagId, requestId)
return NextResponse.json({
success: true,
message: `Tag definition "${tagDef.displayName}" deleted successfully`,
message: `Tag definition "${deletedTag.displayName}" deleted successfully`,
})
} catch (error) {
logger.error(`[${requestId}] Error deleting tag definition`, error)

View File

@@ -1,11 +1,11 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { SUPPORTED_FIELD_TYPES } from '@/lib/constants/knowledge'
import { createTagDefinition, getTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -24,25 +24,12 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
fieldType: knowledgeBaseTagDefinitions.fieldType,
createdAt: knowledgeBaseTagDefinitions.createdAt,
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
const tagDefinitions = await getTagDefinitions(knowledgeBaseId)
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
@@ -69,68 +56,43 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
const body = await req.json()
const { tagSlot, displayName, fieldType } = body
if (!tagSlot || !displayName || !fieldType) {
return NextResponse.json(
{ error: 'tagSlot, displayName, and fieldType are required' },
{ status: 400 }
)
}
const CreateTagDefinitionSchema = z.object({
tagSlot: z.string().min(1, 'Tag slot is required'),
displayName: z.string().min(1, 'Display name is required'),
fieldType: z.enum(SUPPORTED_FIELD_TYPES as [string, ...string[]], {
errorMap: () => ({ message: 'Invalid field type' }),
}),
})
// Check if tag slot is already used
const existingTag = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.tagSlot, tagSlot)
let validatedData
try {
validatedData = CreateTagDefinitionSchema.parse(body)
} catch (error) {
if (error instanceof z.ZodError) {
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
)
.limit(1)
if (existingTag.length > 0) {
return NextResponse.json({ error: 'Tag slot is already in use' }, { status: 409 })
}
throw error
}
// Check if display name is already used
const existingName = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.displayName, displayName)
)
)
.limit(1)
if (existingName.length > 0) {
return NextResponse.json({ error: 'Tag name is already in use' }, { status: 409 })
}
// Create the new tag definition
const newTagDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot,
displayName,
fieldType,
createdAt: new Date(),
updatedAt: new Date(),
}
await db.insert(knowledgeBaseTagDefinitions).values(newTagDefinition)
logger.info(`[${requestId}] Successfully created tag definition ${displayName} (${tagSlot})`)
const newTagDefinition = await createTagDefinition(
{
knowledgeBaseId,
tagSlot: validatedData.tagSlot,
displayName: validatedData.displayName,
fieldType: validatedData.fieldType,
},
requestId
)
return NextResponse.json({
success: true,

View File

@@ -1,11 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq, isNotNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getTagUsage } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -24,57 +22,15 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get all tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
// Get usage statistics for each tag definition
const usageStats = await Promise.all(
tagDefinitions.map(async (tagDef) => {
// Count documents using this tag slot
const tagSlotColumn = tagDef.tagSlot as keyof typeof document.$inferSelect
const documentsWithTag = await db
.select({
id: document.id,
filename: document.filename,
[tagDef.tagSlot]: document[tagSlotColumn as keyof typeof document.$inferSelect] as any,
})
.from(document)
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
isNotNull(document[tagSlotColumn as keyof typeof document.$inferSelect])
)
)
return {
tagName: tagDef.displayName,
tagSlot: tagDef.tagSlot,
documentCount: documentsWithTag.length,
documents: documentsWithTag.map((doc) => ({
id: doc.id,
name: doc.filename,
tagValue: doc[tagDef.tagSlot],
})),
}
})
)
const usageStats = await getTagUsage(knowledgeBaseId, requestId)
logger.info(
`[${requestId}] Retrieved usage statistics for ${tagDefinitions.length} tag definitions`
`[${requestId}] Retrieved usage statistics for ${usageStats.length} tag definitions`
)
return NextResponse.json({

View File

@@ -1,11 +1,8 @@
import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/permissions/utils'
import { db } from '@/db'
import { document, knowledgeBase, permissions } from '@/db/schema'
const logger = createLogger('KnowledgeBaseAPI')
@@ -41,60 +38,10 @@ export async function GET(req: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check for workspace filtering
const { searchParams } = new URL(req.url)
const workspaceId = searchParams.get('workspaceId')
// Get knowledge bases that user can access through direct ownership OR workspace permissions
const knowledgeBasesWithCounts = await db
.select({
id: knowledgeBase.id,
name: knowledgeBase.name,
description: knowledgeBase.description,
tokenCount: knowledgeBase.tokenCount,
embeddingModel: knowledgeBase.embeddingModel,
embeddingDimension: knowledgeBase.embeddingDimension,
chunkingConfig: knowledgeBase.chunkingConfig,
createdAt: knowledgeBase.createdAt,
updatedAt: knowledgeBase.updatedAt,
workspaceId: knowledgeBase.workspaceId,
docCount: count(document.id),
})
.from(knowledgeBase)
.leftJoin(
document,
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
)
.leftJoin(
permissions,
and(
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, knowledgeBase.workspaceId),
eq(permissions.userId, session.user.id)
)
)
.where(
and(
isNull(knowledgeBase.deletedAt),
workspaceId
? // When filtering by workspace
or(
// Knowledge bases belonging to the specified workspace (user must have workspace permissions)
and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
// Fallback: User-owned knowledge bases without workspace (legacy)
and(eq(knowledgeBase.userId, session.user.id), isNull(knowledgeBase.workspaceId))
)
: // When not filtering by workspace, use original logic
or(
// User owns the knowledge base directly
eq(knowledgeBase.userId, session.user.id),
// User has permissions on the knowledge base's workspace
isNotNull(permissions.userId)
)
)
)
.groupBy(knowledgeBase.id)
.orderBy(knowledgeBase.createdAt)
const knowledgeBasesWithCounts = await getKnowledgeBases(session.user.id, workspaceId)
return NextResponse.json({
success: true,
@@ -121,49 +68,16 @@ export async function POST(req: NextRequest) {
try {
const validatedData = CreateKnowledgeBaseSchema.parse(body)
// If creating in a workspace, check if user has write/admin permissions
if (validatedData.workspaceId) {
const userPermission = await getUserEntityPermissions(
session.user.id,
'workspace',
validatedData.workspaceId
)
if (userPermission !== 'write' && userPermission !== 'admin') {
logger.warn(
`[${requestId}] User ${session.user.id} denied permission to create knowledge base in workspace ${validatedData.workspaceId}`
)
return NextResponse.json(
{ error: 'Insufficient permissions to create knowledge base in this workspace' },
{ status: 403 }
)
}
}
const id = crypto.randomUUID()
const now = new Date()
const newKnowledgeBase = {
id,
const createData = {
...validatedData,
userId: session.user.id,
workspaceId: validatedData.workspaceId || null,
name: validatedData.name,
description: validatedData.description || null,
tokenCount: 0,
embeddingModel: validatedData.embeddingModel,
embeddingDimension: validatedData.embeddingDimension,
chunkingConfig: validatedData.chunkingConfig || {
maxSize: 1024,
minSize: 100,
overlap: 200,
},
docCount: 0,
createdAt: now,
updatedAt: now,
}
await db.insert(knowledgeBase).values(newKnowledgeBase)
const newKnowledgeBase = await createKnowledgeBase(createData, requestId)
logger.info(`[${requestId}] Knowledge base created: ${id} for user ${session.user.id}`)
logger.info(
`[${requestId}] Knowledge base created: ${newKnowledgeBase.id} for user ${session.user.id}`
)
return NextResponse.json({
success: true,

View File
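// A sketch of createKnowledgeBase as called above, reusing the id generation
// and defaults from the inlined code it replaced. The input type is abbreviated
// for illustration and is an assumption.
import { db } from '@/db'
import { knowledgeBase } from '@/db/schema'

interface CreateKnowledgeBaseInput {
  userId: string
  name: string
  description?: string | null
  workspaceId?: string | null
  embeddingModel: string
  embeddingDimension: number
  chunkingConfig?: { maxSize: number; minSize: number; overlap: number }
}

export async function createKnowledgeBaseSketch(data: CreateKnowledgeBaseInput, requestId: string) {
  const now = new Date()
  const record = {
    id: crypto.randomUUID(),
    ...data,
    description: data.description ?? null,
    workspaceId: data.workspaceId ?? null,
    tokenCount: 0,
    // Same chunking defaults as the removed inline code
    chunkingConfig: data.chunkingConfig ?? { maxSize: 1024, minSize: 100, overlap: 200 },
    createdAt: now,
    updatedAt: now,
  }
  await db.insert(knowledgeBase).values(record)
  return record
}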

@@ -65,12 +65,14 @@ const mockHandleVectorOnlySearch = vi.fn()
const mockHandleTagAndVectorSearch = vi.fn()
const mockGetQueryStrategy = vi.fn()
const mockGenerateSearchEmbedding = vi.fn()
const mockGetDocumentNamesByIds = vi.fn()
vi.mock('./utils', () => ({
handleTagOnlySearch: mockHandleTagOnlySearch,
handleVectorOnlySearch: mockHandleVectorOnlySearch,
handleTagAndVectorSearch: mockHandleTagAndVectorSearch,
getQueryStrategy: mockGetQueryStrategy,
generateSearchEmbedding: mockGenerateSearchEmbedding,
getDocumentNamesByIds: mockGetDocumentNamesByIds,
APIError: class APIError extends Error {
public status: number
constructor(message: string, status: number) {
@@ -146,6 +148,10 @@ describe('Knowledge Search API Route', () => {
singleQueryOptimized: true,
})
mockGenerateSearchEmbedding.mockClear().mockResolvedValue([0.1, 0.2, 0.3, 0.4, 0.5])
mockGetDocumentNamesByIds.mockClear().mockResolvedValue({
doc1: 'Document 1',
doc2: 'Document 2',
})
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),

View File

@@ -1,16 +1,15 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { TAG_SLOTS } from '@/lib/constants/knowledge'
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
import { calculateCost } from '@/providers/utils'
import {
generateSearchEmbedding,
getDocumentNamesByIds,
getQueryStrategy,
handleTagAndVectorSearch,
handleTagOnlySearch,
@@ -79,14 +78,13 @@ export async function POST(request: NextRequest) {
? validatedData.knowledgeBaseIds
: [validatedData.knowledgeBaseIds]
// Check access permissions for each knowledge base using proper workspace-based permissions
const accessibleKbIds: string[] = []
for (const kbId of knowledgeBaseIds) {
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
if (accessCheck.hasAccess) {
accessibleKbIds.push(kbId)
}
}
// Check access permissions in parallel for performance
const accessChecks = await Promise.all(
knowledgeBaseIds.map((kbId) => checkKnowledgeBaseAccess(kbId, userId))
)
const accessibleKbIds: string[] = knowledgeBaseIds.filter(
(_, idx) => accessChecks[idx]?.hasAccess
)
// Map display names to tag slots for filtering
let mappedFilters: Record<string, string> = {}
@@ -94,13 +92,7 @@ export async function POST(request: NextRequest) {
try {
// Fetch tag definitions for the first accessible KB (since we're using single KB now)
const kbId = accessibleKbIds[0]
const tagDefs = await db
.select({
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
const tagDefs = await getDocumentTagDefinitions(kbId)
logger.debug(`[${requestId}] Found tag definitions:`, tagDefs)
logger.debug(`[${requestId}] Original filters:`, validatedData.filters)
@@ -145,7 +137,10 @@ export async function POST(request: NextRequest) {
// Generate query embedding only if query is provided
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
const queryEmbedding = hasQuery ? await generateSearchEmbedding(validatedData.query!) : null
// Start embedding generation early and await when needed
const queryEmbeddingPromise = hasQuery
? generateSearchEmbedding(validatedData.query!)
: Promise.resolve(null)
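// (The await is deferred to the JSON.stringify(await queryEmbeddingPromise)
// call sites below, so the embedding round trip overlaps the tag-definition
// lookups instead of blocking them.)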
// Check if any requested knowledge bases were not accessible
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
@@ -173,7 +168,7 @@ export async function POST(request: NextRequest) {
// Tag + Vector search
logger.debug(`[${requestId}] Executing tag + vector search with filters:`, mappedFilters)
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
const queryVector = JSON.stringify(queryEmbedding)
const queryVector = JSON.stringify(await queryEmbeddingPromise)
results = await handleTagAndVectorSearch({
knowledgeBaseIds: accessibleKbIds,
@@ -186,7 +181,7 @@ export async function POST(request: NextRequest) {
// Vector-only search
logger.debug(`[${requestId}] Executing vector-only search`)
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
const queryVector = JSON.stringify(queryEmbedding)
const queryVector = JSON.stringify(await queryEmbeddingPromise)
results = await handleVectorOnlySearch({
knowledgeBaseIds: accessibleKbIds,
@@ -221,30 +216,32 @@ export async function POST(request: NextRequest) {
}
// Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
for (const kbId of accessibleKbIds) {
try {
const tagDefs = await db
.select({
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
tagDefinitionsMap[kbId] = {}
tagDefs.forEach((def) => {
tagDefinitionsMap[kbId][def.tagSlot] = def.displayName
})
logger.debug(
`[${requestId}] Display mapping - KB ${kbId} tag definitions:`,
tagDefinitionsMap[kbId]
)
} catch (error) {
logger.warn(`[${requestId}] Failed to fetch tag definitions for display mapping:`, error)
tagDefinitionsMap[kbId] = {}
}
}
const tagDefsResults = await Promise.all(
accessibleKbIds.map(async (kbId) => {
try {
const tagDefs = await getDocumentTagDefinitions(kbId)
const map: Record<string, string> = {}
tagDefs.forEach((def) => {
map[def.tagSlot] = def.displayName
})
return { kbId, map }
} catch (error) {
logger.warn(
`[${requestId}] Failed to fetch tag definitions for display mapping:`,
error
)
return { kbId, map: {} as Record<string, string> }
}
})
)
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
tagDefsResults.forEach(({ kbId, map }) => {
tagDefinitionsMap[kbId] = map
})
// Fetch document names for the results
const documentIds = results.map((result) => result.documentId)
const documentNameMap = await getDocumentNamesByIds(documentIds)
return NextResponse.json({
success: true,
@@ -271,11 +268,11 @@ export async function POST(request: NextRequest) {
})
return {
id: result.id,
content: result.content,
documentId: result.documentId,
documentName: documentNameMap[result.documentId] || undefined,
content: result.content,
chunkIndex: result.chunkIndex,
tags, // Clean display name mapped tags
metadata: tags, // Clean display name mapped tags
similarity: hasQuery ? 1 - result.distance : 1, // Perfect similarity for tag-only searches
}
}),

View File

@@ -16,7 +16,7 @@ vi.mock('@/lib/logs/console/logger', () => ({
})),
}))
vi.mock('@/db')
vi.mock('@/lib/documents/utils', () => ({
vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))

View File

@@ -1,10 +1,34 @@
import { and, eq, inArray, sql } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { embedding } from '@/db/schema'
import { document, embedding } from '@/db/schema'
const logger = createLogger('KnowledgeSearchUtils')
export async function getDocumentNamesByIds(
documentIds: string[]
): Promise<Record<string, string>> {
if (documentIds.length === 0) {
return {}
}
const uniqueIds = [...new Set(documentIds)]
const documents = await db
.select({
id: document.id,
filename: document.filename,
})
.from(document)
.where(inArray(document.id, uniqueIds))
const documentNameMap: Record<string, string> = {}
documents.forEach((doc) => {
documentNameMap[doc.id] = doc.filename
})
return documentNameMap
}
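// Example (values illustrative): duplicate ids are collapsed before querying,
// and ids with no matching document are simply absent from the returned map.
//   const names = await getDocumentNamesByIds(['doc-1', 'doc-1', 'doc-2'])
//   names['doc-1'] // e.g. 'quarterly-report.pdf'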
export interface SearchResult {
id: string
content: string

View File

@@ -21,11 +21,11 @@ vi.mock('@/lib/env', () => ({
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
}))
vi.mock('@/lib/documents/utils', () => ({
vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))
vi.mock('@/lib/documents/document-processor', () => ({
vi.mock('@/lib/knowledge/documents/document-processor', () => ({
processDocument: vi.fn().mockResolvedValue({
chunks: [
{
@@ -149,12 +149,12 @@ vi.mock('@/db', () => {
}
})
import { generateEmbeddings } from '@/lib/embeddings/utils'
import { processDocumentAsync } from '@/lib/knowledge/documents/service'
import {
checkChunkAccess,
checkDocumentAccess,
checkKnowledgeBaseAccess,
generateEmbeddings,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
describe('Knowledge Utils', () => {

View File

@@ -1,35 +1,8 @@
import crypto from 'crypto'
import { and, eq, isNull } from 'drizzle-orm'
import { processDocument } from '@/lib/documents/document-processor'
import { generateEmbeddings } from '@/lib/embeddings/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/permissions/utils'
import { db } from '@/db'
import { document, embedding, knowledgeBase } from '@/db/schema'
const logger = createLogger('KnowledgeUtils')
const TIMEOUTS = {
OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes)
EMBEDDINGS_API: 60000, // 60 seconds per batch
} as const
/**
* Create a timeout wrapper for async operations
*/
function withTimeout<T>(
promise: Promise<T>,
timeoutMs: number,
operation = 'Operation'
): Promise<T> {
return Promise.race([
promise,
new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error(`${operation} timed out after ${timeoutMs}ms`)), timeoutMs)
),
])
}
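// Example usage of the helper above (values illustrative):
//   const result = await withTimeout(someSlowPromise, 5_000, 'Slow operation')
//   // rejects with 'Slow operation timed out after 5000ms' if the promise stalls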
export interface KnowledgeBaseData {
id: string
userId: string
@@ -380,154 +353,3 @@ export async function checkChunkAccess(
knowledgeBase: kbAccess.knowledgeBase!,
}
}
// Export for external use
export { generateEmbeddings }
/**
* Process a document asynchronously with full error handling
*/
export async function processDocumentAsync(
knowledgeBaseId: string,
documentId: string,
docData: {
filename: string
fileUrl: string
fileSize: number
mimeType: string
},
processingOptions: {
chunkSize?: number
minCharactersPerChunk?: number
recipe?: string
lang?: string
chunkOverlap?: number
}
): Promise<void> {
const startTime = Date.now()
try {
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
// Set status to processing
await db
.update(document)
.set({
processingStatus: 'processing',
processingStartedAt: new Date(),
processingError: null, // Clear any previous error
})
.where(eq(document.id, documentId))
logger.info(`[${documentId}] Status updated to 'processing', starting document processor`)
// Wrap the entire processing operation with the overall timeout (150 seconds; see TIMEOUTS above)
await withTimeout(
(async () => {
const processed = await processDocument(
docData.fileUrl,
docData.filename,
docData.mimeType,
processingOptions.chunkSize || 1000,
processingOptions.chunkOverlap || 200,
processingOptions.minCharactersPerChunk || 1
)
const now = new Date()
logger.info(
`[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks`
)
const chunkTexts = processed.chunks.map((chunk) => chunk.text)
const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : []
logger.info(`[${documentId}] Embeddings generated, fetching document tags`)
// Fetch document to get tags
const documentRecord = await db
.select({
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
})
.from(document)
.where(eq(document.id, documentId))
.limit(1)
const documentTags = documentRecord[0] || {}
logger.info(`[${documentId}] Creating embedding records with tags`)
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
id: crypto.randomUUID(),
knowledgeBaseId,
documentId,
chunkIndex,
chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'),
content: chunk.text,
contentLength: chunk.text.length,
tokenCount: Math.ceil(chunk.text.length / 4),
embedding: embeddings[chunkIndex] || null,
embeddingModel: 'text-embedding-3-small',
startOffset: chunk.metadata.startIndex,
endOffset: chunk.metadata.endIndex,
// Copy tags from document
tag1: documentTags.tag1,
tag2: documentTags.tag2,
tag3: documentTags.tag3,
tag4: documentTags.tag4,
tag5: documentTags.tag5,
tag6: documentTags.tag6,
tag7: documentTags.tag7,
createdAt: now,
updatedAt: now,
}))
await db.transaction(async (tx) => {
if (embeddingRecords.length > 0) {
await tx.insert(embedding).values(embeddingRecords)
}
await tx
.update(document)
.set({
chunkCount: processed.metadata.chunkCount,
tokenCount: processed.metadata.tokenCount,
characterCount: processed.metadata.characterCount,
processingStatus: 'completed',
processingCompletedAt: now,
processingError: null,
})
.where(eq(document.id, documentId))
})
})(),
TIMEOUTS.OVERALL_PROCESSING,
'Document processing'
)
const processingTime = Date.now() - startTime
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
} catch (error) {
const processingTime = Date.now() - startTime
logger.error(`[${documentId}] Failed to process document after ${processingTime}ms:`, {
error: error instanceof Error ? error.message : 'Unknown error',
stack: error instanceof Error ? error.stack : undefined,
filename: docData.filename,
fileUrl: docData.fileUrl,
mimeType: docData.mimeType,
})
await db
.update(document)
.set({
processingStatus: 'failed',
processingError: error instanceof Error ? error.message : 'Unknown error',
processingCompletedAt: new Date(),
})
.where(eq(document.id, documentId))
}
}

View File

@@ -73,30 +73,59 @@ export async function GET(request: NextRequest) {
const { searchParams } = new URL(request.url)
const params = QueryParamsSchema.parse(Object.fromEntries(searchParams.entries()))
// Conditionally select columns based on detail level to optimize performance
const selectColumns =
params.details === 'full'
? {
id: workflowExecutionLogs.id,
workflowId: workflowExecutionLogs.workflowId,
executionId: workflowExecutionLogs.executionId,
stateSnapshotId: workflowExecutionLogs.stateSnapshotId,
level: workflowExecutionLogs.level,
trigger: workflowExecutionLogs.trigger,
startedAt: workflowExecutionLogs.startedAt,
endedAt: workflowExecutionLogs.endedAt,
totalDurationMs: workflowExecutionLogs.totalDurationMs,
executionData: workflowExecutionLogs.executionData, // Large field - only in full mode
cost: workflowExecutionLogs.cost,
files: workflowExecutionLogs.files, // Large field - only in full mode
createdAt: workflowExecutionLogs.createdAt,
workflowName: workflow.name,
workflowDescription: workflow.description,
workflowColor: workflow.color,
workflowFolderId: workflow.folderId,
workflowUserId: workflow.userId,
workflowWorkspaceId: workflow.workspaceId,
workflowCreatedAt: workflow.createdAt,
workflowUpdatedAt: workflow.updatedAt,
}
: {
// Basic mode - exclude large fields for better performance
id: workflowExecutionLogs.id,
workflowId: workflowExecutionLogs.workflowId,
executionId: workflowExecutionLogs.executionId,
stateSnapshotId: workflowExecutionLogs.stateSnapshotId,
level: workflowExecutionLogs.level,
trigger: workflowExecutionLogs.trigger,
startedAt: workflowExecutionLogs.startedAt,
endedAt: workflowExecutionLogs.endedAt,
totalDurationMs: workflowExecutionLogs.totalDurationMs,
executionData: sql<null>`NULL`, // Exclude large execution data in basic mode
cost: workflowExecutionLogs.cost,
files: sql<null>`NULL`, // Exclude files in basic mode
createdAt: workflowExecutionLogs.createdAt,
workflowName: workflow.name,
workflowDescription: workflow.description,
workflowColor: workflow.color,
workflowFolderId: workflow.folderId,
workflowUserId: workflow.userId,
workflowWorkspaceId: workflow.workspaceId,
workflowCreatedAt: workflow.createdAt,
workflowUpdatedAt: workflow.updatedAt,
}
const baseQuery = db
.select({
id: workflowExecutionLogs.id,
workflowId: workflowExecutionLogs.workflowId,
executionId: workflowExecutionLogs.executionId,
stateSnapshotId: workflowExecutionLogs.stateSnapshotId,
level: workflowExecutionLogs.level,
trigger: workflowExecutionLogs.trigger,
startedAt: workflowExecutionLogs.startedAt,
endedAt: workflowExecutionLogs.endedAt,
totalDurationMs: workflowExecutionLogs.totalDurationMs,
executionData: workflowExecutionLogs.executionData,
cost: workflowExecutionLogs.cost,
files: workflowExecutionLogs.files,
createdAt: workflowExecutionLogs.createdAt,
workflowName: workflow.name,
workflowDescription: workflow.description,
workflowColor: workflow.color,
workflowFolderId: workflow.folderId,
workflowUserId: workflow.userId,
workflowWorkspaceId: workflow.workspaceId,
workflowCreatedAt: workflow.createdAt,
workflowUpdatedAt: workflow.updatedAt,
})
.select(selectColumns)
.from(workflowExecutionLogs)
.innerJoin(workflow, eq(workflowExecutionLogs.workflowId, workflow.id))
.innerJoin(
@@ -276,18 +305,24 @@ export async function GET(request: NextRequest) {
const enhancedLogs = logs.map((log) => {
const blockExecutions = blockExecutionsByExecution[log.executionId] || []
// Use stored trace spans if available, otherwise create from block executions
const storedTraceSpans = (log.executionData as any)?.traceSpans
const traceSpans =
storedTraceSpans && Array.isArray(storedTraceSpans) && storedTraceSpans.length > 0
? storedTraceSpans
: createTraceSpans(blockExecutions)
// Prefer stored cost JSON; otherwise synthesize from blocks
const costSummary =
log.cost && Object.keys(log.cost as any).length > 0
? (log.cost as any)
: extractCostSummary(blockExecutions)
// Only process trace spans and detailed cost in full mode
let traceSpans = []
let costSummary = (log.cost as any) || { total: 0 }
if (params.details === 'full' && log.executionData) {
// Use stored trace spans if available, otherwise create from block executions
const storedTraceSpans = (log.executionData as any)?.traceSpans
traceSpans =
storedTraceSpans && Array.isArray(storedTraceSpans) && storedTraceSpans.length > 0
? storedTraceSpans
: createTraceSpans(blockExecutions)
// Prefer stored cost JSON; otherwise synthesize from blocks
costSummary =
log.cost && Object.keys(log.cost as any).length > 0
? (log.cost as any)
: extractCostSummary(blockExecutions)
}
const workflowSummary = {
id: log.workflowId,

View File
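// Summary of the logs route change above: details=basic selects SQL NULL for
// the large executionData and files columns (keeping the row shape stable while
// skipping the heavy payload) and returns empty trace spans plus the stored
// cost total; details=full keeps the previous behavior. Query examples (the
// endpoint path is assumed, since the diff viewer omits file names):
//   GET /api/logs?details=basic -> executionData: null, files: null, traceSpans: []
//   GET /api/logs?details=full  -> full payload with reconstructed trace spans and cost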

@@ -1,6 +1,7 @@
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getUserUsageData } from '@/lib/billing/core/usage'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { member, user, userStats } from '@/db/schema'
@@ -80,9 +81,6 @@ export async function GET(
.select({
currentPeriodCost: userStats.currentPeriodCost,
currentUsageLimit: userStats.currentUsageLimit,
billingPeriodStart: userStats.billingPeriodStart,
billingPeriodEnd: userStats.billingPeriodEnd,
usageLimitSetBy: userStats.usageLimitSetBy,
usageLimitUpdatedAt: userStats.usageLimitUpdatedAt,
lastPeriodCost: userStats.lastPeriodCost,
})
@@ -90,11 +88,22 @@ export async function GET(
.where(eq(userStats.userId, memberId))
.limit(1)
const computed = await getUserUsageData(memberId)
if (usageData.length > 0) {
memberData = {
...memberData,
usage: usageData[0],
} as typeof memberData & { usage: (typeof usageData)[0] }
usage: {
...usageData[0],
billingPeriodStart: computed.billingPeriodStart,
billingPeriodEnd: computed.billingPeriodEnd,
},
} as typeof memberData & {
usage: (typeof usageData)[0] & {
billingPeriodStart: Date | null
billingPeriodEnd: Date | null
}
}
}
}
@@ -180,6 +189,11 @@ export async function PUT(
)
}
// Prevent admins from changing other admins' roles - only owners can modify admin roles
if (targetMember[0].role === 'admin' && userMember[0].role !== 'owner') {
return NextResponse.json({ error: 'Only owners can change admin roles' }, { status: 403 })
}
// Update member role
const updatedMember = await db
.update(member)

View File

@@ -3,6 +3,7 @@ import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getEmailSubject, renderInvitationEmail } from '@/components/emails/render-email'
import { getSession } from '@/lib/auth'
import { getUserUsageData } from '@/lib/billing/core/usage'
import { validateSeatAvailability } from '@/lib/billing/validation/seat-management'
import { sendEmail } from '@/lib/email/mailer'
import { quickValidateEmail } from '@/lib/email/validation'
@@ -63,7 +64,7 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
// Include usage data if requested and user has admin access
if (includeUsage && hasAdminAccess) {
const membersWithUsage = await db
const base = await db
.select({
id: member.id,
userId: member.userId,
@@ -74,9 +75,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
userEmail: user.email,
currentPeriodCost: userStats.currentPeriodCost,
currentUsageLimit: userStats.currentUsageLimit,
billingPeriodStart: userStats.billingPeriodStart,
billingPeriodEnd: userStats.billingPeriodEnd,
usageLimitSetBy: userStats.usageLimitSetBy,
usageLimitUpdatedAt: userStats.usageLimitUpdatedAt,
})
.from(member)
@@ -84,6 +82,17 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
.leftJoin(userStats, eq(user.id, userStats.userId))
.where(eq(member.organizationId, organizationId))
const membersWithUsage = await Promise.all(
base.map(async (row) => {
const usage = await getUserUsageData(row.userId)
return {
...row,
billingPeriodStart: usage.billingPeriodStart,
billingPeriodEnd: usage.billingPeriodEnd,
}
})
)
return NextResponse.json({
success: true,
data: membersWithUsage,

View File

@@ -5,7 +5,7 @@ import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { invitation, member, permissions, workspaceInvitation } from '@/db/schema'
import { invitation, member, permissions, user, workspaceInvitation } from '@/db/schema'
const logger = createLogger('OrganizationInvitationAcceptanceAPI')
@@ -70,11 +70,33 @@ export async function GET(req: NextRequest) {
)
}
// Get user data to check email verification status
const userData = await db.select().from(user).where(eq(user.id, session.user.id)).limit(1)
if (userData.length === 0) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=user-not-found',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Check if user's email is verified
if (!userData[0].emailVerified) {
return NextResponse.redirect(
new URL(
`/invite/invite-error?reason=email-not-verified&details=${encodeURIComponent(`You must verify your email address (${userData[0].email}) before accepting invitations.`)}`,
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Verify the email matches the current user
if (orgInvitation.email !== session.user.email) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=email-mismatch',
`/invite/invite-error?reason=email-mismatch&details=${encodeURIComponent(`Invitation was sent to ${orgInvitation.email}, but you're logged in as ${userData[0].email}`)}`,
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
@@ -235,6 +257,24 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ error: 'Invitation already processed' }, { status: 400 })
}
// Get user data to check email verification status
const userData = await db.select().from(user).where(eq(user.id, session.user.id)).limit(1)
if (userData.length === 0) {
return NextResponse.json({ error: 'User not found' }, { status: 404 })
}
// Check if user's email is verified
if (!userData[0].emailVerified) {
return NextResponse.json(
{
error: 'Email not verified',
message: `You must verify your email address (${userData[0].email}) before accepting invitations.`,
},
{ status: 403 }
)
}
if (orgInvitation.email !== session.user.email) {
return NextResponse.json({ error: 'Email mismatch' }, { status: 403 })
}

View File
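
The same load-user / reject-unverified sequence now appears in both the GET redirect path and the POST JSON path above. A sketch of how it could be factored into one guard, assuming the route's existing db, user, and eq imports (the helper itself is hypothetical, not part of the diff):

type VerifiedUserCheck =
  | { ok: true; email: string }
  | { ok: false; reason: 'user-not-found' | 'email-not-verified'; email?: string }

async function checkVerifiedUser(userId: string): Promise<VerifiedUserCheck> {
  // Mirrors the inline checks above: missing row first, then unverified email.
  const rows = await db.select().from(user).where(eq(user.id, userId)).limit(1)
  if (rows.length === 0) return { ok: false, reason: 'user-not-found' }
  if (!rows[0].emailVerified) {
    return { ok: false, reason: 'email-not-verified', email: rows[0].email }
  }
  return { ok: true, email: rows[0].email }
}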

@@ -0,0 +1,73 @@
import { NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { createOrganizationForTeamPlan } from '@/lib/billing/organization'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CreateTeamOrganization')
export async function POST(request: Request) {
try {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized - no active session' }, { status: 401 })
}
const user = session.user
// Parse request body for optional name and slug
let organizationName = user.name
let organizationSlug: string | undefined
try {
const body = await request.json()
if (body.name && typeof body.name === 'string') {
organizationName = body.name
}
if (body.slug && typeof body.slug === 'string') {
organizationSlug = body.slug
}
} catch {
// If no body or invalid JSON, use defaults
}
logger.info('Creating organization for team plan', {
userId: user.id,
userName: user.name,
userEmail: user.email,
organizationName,
organizationSlug,
})
// Create organization and make user the owner/admin
const organizationId = await createOrganizationForTeamPlan(
user.id,
organizationName || undefined,
user.email,
organizationSlug
)
logger.info('Successfully created organization for team plan', {
userId: user.id,
organizationId,
})
return NextResponse.json({
success: true,
organizationId,
})
} catch (error) {
logger.error('Failed to create organization for team plan', {
error: error instanceof Error ? error.message : 'Unknown error',
stack: error instanceof Error ? error.stack : undefined,
})
return NextResponse.json(
{
error: 'Failed to create organization',
message: error instanceof Error ? error.message : 'Unknown error',
},
{ status: 500 }
)
}
}

View File
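
A caller supplies an optional name and slug; both fall back to session defaults when the body is missing or not valid JSON. A usage sketch (the route path is an assumption, not taken from the diff; substitute wherever this handler is mounted):

// Hypothetical client call; '/api/organizations/team' is an assumed path.
const res = await fetch('/api/organizations/team', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ name: 'Acme Inc', slug: 'acme-inc' }),
})
const { success, organizationId } = await res.json()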

@@ -0,0 +1,46 @@
import { type NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('OpenRouterModelsAPI')
export const dynamic = 'force-dynamic'
export async function GET(_request: NextRequest) {
try {
const response = await fetch('https://openrouter.ai/api/v1/models', {
headers: { 'Content-Type': 'application/json' },
cache: 'no-store',
})
if (!response.ok) {
logger.warn('Failed to fetch OpenRouter models', {
status: response.status,
statusText: response.statusText,
})
return NextResponse.json({ models: [] })
}
const data = await response.json()
const models = Array.isArray(data?.data)
? Array.from(
new Set(
data.data
.map((m: any) => m?.id)
.filter((id: unknown): id is string => typeof id === 'string' && id.length > 0)
.map((id: string) => `openrouter/${id}`)
)
)
: []
logger.info('Successfully fetched OpenRouter models', {
count: models.length,
})
return NextResponse.json({ models })
} catch (error) {
logger.error('Error fetching OpenRouter models', {
error: error instanceof Error ? error.message : 'Unknown error',
})
return NextResponse.json({ models: [] })
}
}

View File
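
The handler above reduces OpenRouter's raw payload to a deduplicated list of prefixed ids and degrades to an empty list on any failure. The same shaping in isolation, as a standalone sketch:

// Turn a raw { data: [{ id }, ...] } payload into unique "openrouter/<id>" strings.
function toPrefixedModelIds(payload: unknown): string[] {
  const data = (payload as { data?: unknown })?.data
  if (!Array.isArray(data)) return []
  return Array.from(
    new Set(
      data
        .map((m: any) => m?.id)
        .filter((id: unknown): id is string => typeof id === 'string' && id.length > 0)
        .map((id) => `openrouter/${id}`)
    )
  )
}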

@@ -39,6 +39,9 @@ export async function POST(request: NextRequest) {
stream,
messages,
environmentVariables,
workflowVariables,
blockData,
blockNameMapping,
reasoningEffort,
verbosity,
} = body
@@ -60,6 +63,7 @@ export async function POST(request: NextRequest) {
messageCount: messages?.length || 0,
hasEnvironmentVariables:
!!environmentVariables && Object.keys(environmentVariables).length > 0,
hasWorkflowVariables: !!workflowVariables && Object.keys(workflowVariables).length > 0,
reasoningEffort,
verbosity,
})
@@ -103,6 +107,9 @@ export async function POST(request: NextRequest) {
stream,
messages,
environmentVariables,
workflowVariables,
blockData,
blockNameMapping,
reasoningEffort,
verbosity,
})

View File
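
The hunk above threads three new fields from the request body through to the downstream call. The field names come from the diff; the value types below are assumptions for illustration only:

// Assumed shapes for the new pass-through fields (not taken from the source).
interface WorkflowContextFields {
  workflowVariables?: Record<string, unknown>
  blockData?: Record<string, unknown>
  blockNameMapping?: Record<string, string>
}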

@@ -1,5 +1,6 @@
import { type NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
import { validateImageUrl } from '@/lib/security/url-validation'
const logger = createLogger('ImageProxyAPI')
@@ -17,10 +18,18 @@ export async function GET(request: NextRequest) {
return new NextResponse('Missing URL parameter', { status: 400 })
}
const urlValidation = validateImageUrl(imageUrl)
if (!urlValidation.isValid) {
logger.warn(`[${requestId}] Blocked image proxy request`, {
url: imageUrl.substring(0, 100),
error: urlValidation.error,
})
return new NextResponse(urlValidation.error || 'Invalid image URL', { status: 403 })
}
logger.info(`[${requestId}] Proxying image request for: ${imageUrl}`)
try {
// Use fetch with custom headers that appear more browser-like
const imageResponse = await fetch(imageUrl, {
headers: {
'User-Agent':
@@ -45,10 +54,8 @@ export async function GET(request: NextRequest) {
})
}
// Get image content type from response headers
const contentType = imageResponse.headers.get('content-type') || 'image/jpeg'
// Get the image as a blob
const imageBlob = await imageResponse.blob()
if (imageBlob.size === 0) {
@@ -56,7 +63,6 @@ export async function GET(request: NextRequest) {
return new NextResponse('Empty image received', { status: 404 })
}
// Return the image with appropriate headers
return new NextResponse(imageBlob, {
headers: {
'Content-Type': contentType,

View File
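
validateImageUrl comes from '@/lib/security/url-validation', whose implementation is not part of this diff. A rough sketch of what such a guard typically rejects (entirely an assumption, not the library's actual code; production validators usually also resolve DNS and block private address ranges):

// Hypothetical sketch: reject non-HTTP(S) schemes and obvious loopback hosts.
function validateImageUrlSketch(raw: string): { isValid: boolean; error?: string } {
  let parsed: URL
  try {
    parsed = new URL(raw)
  } catch {
    return { isValid: false, error: 'Malformed URL' }
  }
  if (parsed.protocol !== 'http:' && parsed.protocol !== 'https:') {
    return { isValid: false, error: 'Only http/https URLs are allowed' }
  }
  const host = parsed.hostname
  if (host === 'localhost' || host === '127.0.0.1' || host === '::1') {
    return { isValid: false, error: 'Loopback addresses are not allowed' }
  }
  return { isValid: true }
}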

@@ -1,6 +1,7 @@
import { NextResponse } from 'next/server'
import { isDev } from '@/lib/environment'
import { createLogger } from '@/lib/logs/console/logger'
import { validateProxyUrl } from '@/lib/security/url-validation'
import { executeTool } from '@/tools'
import { getTool, validateRequiredParametersAfterMerge } from '@/tools/utils'
@@ -80,6 +81,15 @@ export async function GET(request: Request) {
return createErrorResponse("Missing 'url' parameter", 400)
}
const urlValidation = validateProxyUrl(targetUrl)
if (!urlValidation.isValid) {
logger.warn(`[${requestId}] Blocked proxy request`, {
url: targetUrl.substring(0, 100),
error: urlValidation.error,
})
return createErrorResponse(urlValidation.error || 'Invalid URL', 403)
}
const method = url.searchParams.get('method') || 'GET'
const bodyParam = url.searchParams.get('body')
@@ -109,7 +119,6 @@ export async function GET(request: Request) {
logger.info(`[${requestId}] Proxying ${method} request to: ${targetUrl}`)
try {
// Forward the request to the target URL with all specified headers
const response = await fetch(targetUrl, {
method: method,
headers: {
@@ -119,7 +128,6 @@ export async function GET(request: Request) {
body: body || undefined,
})
// Get response data
const contentType = response.headers.get('content-type') || ''
let data
@@ -129,7 +137,6 @@ export async function GET(request: Request) {
data = await response.text()
}
// For error responses, include a more descriptive error message
const errorMessage = !response.ok
? data && typeof data === 'object' && data.error
? `${data.error.message || JSON.stringify(data.error)}`
@@ -140,7 +147,6 @@ export async function GET(request: Request) {
logger.error(`[${requestId}] External API error: ${response.status} ${response.statusText}`)
}
// Return the proxied response
return formatResponse({
success: response.ok,
status: response.status,
@@ -166,7 +172,6 @@ export async function POST(request: Request) {
const startTimeISO = startTime.toISOString()
try {
// Parse request body
let requestBody
try {
requestBody = await request.json()
@@ -186,7 +191,6 @@ export async function POST(request: Request) {
logger.info(`[${requestId}] Processing tool: ${toolId}`)
// Get tool
const tool = getTool(toolId)
if (!tool) {
@@ -194,7 +198,6 @@ export async function POST(request: Request) {
throw new Error(`Tool not found: ${toolId}`)
}
// Validate the tool and its parameters
try {
validateRequiredParametersAfterMerge(toolId, tool, params)
} catch (validationError) {
@@ -202,7 +205,6 @@ export async function POST(request: Request) {
error: validationError instanceof Error ? validationError.message : String(validationError),
})
// Add timing information even to error responses
const endTime = new Date()
const endTimeISO = endTime.toISOString()
const duration = endTime.getTime() - startTime.getTime()
@@ -214,14 +216,12 @@ export async function POST(request: Request) {
})
}
// Check if tool has file outputs - if so, don't skip post-processing
const hasFileOutputs =
tool.outputs &&
Object.values(tool.outputs).some(
(output) => output.type === 'file' || output.type === 'file[]'
)
// Execute tool
const result = await executeTool(
toolId,
params,

View File
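
With validateProxyUrl in place, a blocked target now returns 403 before any outbound fetch happens. A client-side usage sketch against the GET handler above (response shape abbreviated):

// Hypothetical call to the proxy route; 'url' and 'method' are the
// query parameters the handler reads.
const target = encodeURIComponent('https://api.example.com/items')
const res = await fetch(`/api/proxy?url=${target}&method=GET`)
const payload = await res.json() // { success, status, data, ... }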

@@ -64,7 +64,9 @@ export async function POST(request: Request) {
return new NextResponse(
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
- { status: 500 }
+ {
+ status: 500,
+ }
)
}
}

View File

@@ -112,7 +112,9 @@ export async function POST(request: NextRequest) {
return new Response(
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
- { status: 500 }
+ {
+ status: 500,
+ }
)
}
}

View File

@@ -474,8 +474,10 @@ export async function GET() {
})
await loggingSession.safeCompleteWithError({
- message: `Schedule execution failed before workflow started: ${earlyError.message}`,
- stackTrace: earlyError.stack,
+ error: {
+ message: `Schedule execution failed before workflow started: ${earlyError.message}`,
+ stackTrace: earlyError.stack,
+ },
})
} catch (loggingError) {
logger.error(
@@ -591,8 +593,10 @@ export async function GET() {
})
await failureLoggingSession.safeCompleteWithError({
- message: `Schedule execution failed: ${error.message}`,
- stackTrace: error.stack,
+ error: {
+ message: `Schedule execution failed: ${error.message}`,
+ stackTrace: error.stack,
+ },
})
} catch (loggingError) {
logger.error(

View File

@@ -0,0 +1,68 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { buildDeleteQuery, createMySQLConnection, executeQuery } from '@/app/api/tools/mysql/utils'
const logger = createLogger('MySQLDeleteAPI')
const DeleteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = DeleteSchema.parse(body)
logger.info(
`[${requestId}] Deleting data from ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const connection = await createMySQLConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const { query, values } = buildDeleteQuery(params.table, params.where)
const result = await executeQuery(connection, query, values)
logger.info(`[${requestId}] Delete executed successfully, ${result.rowCount} row(s) deleted`)
return NextResponse.json({
message: `Data deleted successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await connection.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MySQL delete failed:`, error)
return NextResponse.json({ error: `MySQL delete failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,75 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMySQLConnection, executeQuery, validateQuery } from '@/app/api/tools/mysql/utils'
const logger = createLogger('MySQLExecuteAPI')
const ExecuteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = ExecuteSchema.parse(body)
logger.info(
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
)
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
return NextResponse.json(
{ error: `Query validation failed: ${validation.error}` },
{ status: 400 }
)
}
const connection = await createMySQLConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeQuery(connection, params.query)
logger.info(`[${requestId}] SQL executed successfully, ${result.rowCount} row(s) affected`)
return NextResponse.json({
message: `SQL executed successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await connection.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MySQL execute failed:`, error)
return NextResponse.json({ error: `MySQL execute failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,89 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { buildInsertQuery, createMySQLConnection, executeQuery } from '@/app/api/tools/mysql/utils'
const logger = createLogger('MySQLInsertAPI')
const InsertSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
.record(z.unknown())
.refine((obj) => Object.keys(obj).length > 0, 'Data object cannot be empty'),
z
.string()
.min(1)
.transform((str) => {
try {
const parsed = JSON.parse(str)
if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
throw new Error('Data must be a JSON object')
}
return parsed
} catch (e) {
const errorMsg = e instanceof Error ? e.message : 'Unknown error'
throw new Error(
`Invalid JSON format in data field: ${errorMsg}. Received: ${str.substring(0, 100)}...`
)
}
}),
]),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = InsertSchema.parse(body)
logger.info(
`[${requestId}] Inserting data into ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const connection = await createMySQLConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const { query, values } = buildInsertQuery(params.table, params.data)
const result = await executeQuery(connection, query, values)
logger.info(`[${requestId}] Insert executed successfully, ${result.rowCount} row(s) inserted`)
return NextResponse.json({
message: `Data inserted successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await connection.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MySQL insert failed:`, error)
return NextResponse.json({ error: `MySQL insert failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,75 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMySQLConnection, executeQuery, validateQuery } from '@/app/api/tools/mysql/utils'
const logger = createLogger('MySQLQueryAPI')
const QuerySchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = QuerySchema.parse(body)
logger.info(
`[${requestId}] Executing MySQL query on ${params.host}:${params.port}/${params.database}`
)
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
return NextResponse.json(
{ error: `Query validation failed: ${validation.error}` },
{ status: 400 }
)
}
const connection = await createMySQLConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeQuery(connection, params.query)
logger.info(`[${requestId}] Query executed successfully, returned ${result.rowCount} rows`)
return NextResponse.json({
message: `Query executed successfully. ${result.rowCount} row(s) returned.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await connection.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MySQL query failed:`, error)
return NextResponse.json({ error: `MySQL query failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,87 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { buildUpdateQuery, createMySQLConnection, executeQuery } from '@/app/api/tools/mysql/utils'
const logger = createLogger('MySQLUpdateAPI')
const UpdateSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
.record(z.unknown())
.refine((obj) => Object.keys(obj).length > 0, 'Data object cannot be empty'),
z
.string()
.min(1)
.transform((str) => {
try {
const parsed = JSON.parse(str)
if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
throw new Error('Data must be a JSON object')
}
return parsed
} catch (e) {
throw new Error('Invalid JSON format in data field')
}
}),
]),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = UpdateSchema.parse(body)
logger.info(
`[${requestId}] Updating data in ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const connection = await createMySQLConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const { query, values } = buildUpdateQuery(params.table, params.data, params.where)
const result = await executeQuery(connection, query, values)
logger.info(`[${requestId}] Update executed successfully, ${result.rowCount} row(s) updated`)
return NextResponse.json({
message: `Data updated successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await connection.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MySQL update failed:`, error)
return NextResponse.json({ error: `MySQL update failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,173 @@
import mysql from 'mysql2/promise'
export interface MySQLConnectionConfig {
host: string
port: number
database: string
username: string
password: string
ssl?: string
}
export async function createMySQLConnection(config: MySQLConnectionConfig) {
const connectionConfig: mysql.ConnectionOptions = {
host: config.host,
port: config.port,
database: config.database,
user: config.username,
password: config.password,
}
if (config.ssl === 'required') {
connectionConfig.ssl = { rejectUnauthorized: true }
} else if (config.ssl === 'preferred') {
connectionConfig.ssl = { rejectUnauthorized: false }
}
return mysql.createConnection(connectionConfig)
}
export async function executeQuery(
connection: mysql.Connection,
query: string,
values?: unknown[]
) {
const [rows, fields] = await connection.execute(query, values)
if (Array.isArray(rows)) {
return {
rows: rows as unknown[],
rowCount: rows.length,
fields,
}
}
return {
rows: [],
rowCount: (rows as mysql.ResultSetHeader).affectedRows || 0,
fields,
}
}
export function validateQuery(query: string): { isValid: boolean; error?: string } {
const trimmedQuery = query.trim().toLowerCase()
const dangerousPatterns = [
/drop\s+database/i,
/drop\s+schema/i,
/drop\s+user/i,
/create\s+user/i,
/grant\s+/i,
/revoke\s+/i,
/alter\s+user/i,
/set\s+global/i,
/set\s+session/i,
/load\s+data/i,
/into\s+outfile/i,
/into\s+dumpfile/i,
/load_file\s*\(/i,
/system\s+/i,
/exec\s+/i,
/execute\s+immediate/i,
/xp_cmdshell/i,
/sp_configure/i,
/information_schema\.tables/i,
/mysql\.user/i,
/mysql\.db/i,
/mysql\.host/i,
/performance_schema/i,
/sys\./i,
]
for (const pattern of dangerousPatterns) {
if (pattern.test(query)) {
return {
isValid: false,
error: `Query contains potentially dangerous operation: ${pattern.source}`,
}
}
}
const allowedStatements = /^(select|insert|update|delete|with|show|describe|explain)\s+/i
if (!allowedStatements.test(trimmedQuery)) {
return {
isValid: false,
error:
'Only SELECT, INSERT, UPDATE, DELETE, WITH, SHOW, DESCRIBE, and EXPLAIN statements are allowed',
}
}
return { isValid: true }
}
export function buildInsertQuery(table: string, data: Record<string, unknown>) {
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const values = Object.values(data)
const placeholders = columns.map(() => '?').join(', ')
const query = `INSERT INTO ${sanitizedTable} (${columns.map(sanitizeIdentifier).join(', ')}) VALUES (${placeholders})`
return { query, values }
}
export function buildUpdateQuery(table: string, data: Record<string, unknown>, where: string) {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const values = Object.values(data)
const setClause = columns.map((col) => `${sanitizeIdentifier(col)} = ?`).join(', ')
const query = `UPDATE ${sanitizedTable} SET ${setClause} WHERE ${where}`
return { query, values }
}
export function buildDeleteQuery(table: string, where: string) {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const query = `DELETE FROM ${sanitizedTable} WHERE ${where}`
return { query, values: [] }
}
function validateWhereClause(where: string): void {
const dangerousPatterns = [
/;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
/union\s+select/i,
/into\s+outfile/i,
/load_file/i,
/--/,
/\/\*/,
/\*\//,
]
for (const pattern of dangerousPatterns) {
if (pattern.test(where)) {
throw new Error('WHERE clause contains potentially dangerous operation')
}
}
}
export function sanitizeIdentifier(identifier: string): string {
if (identifier.includes('.')) {
const parts = identifier.split('.')
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
}
return sanitizeSingleIdentifier(identifier)
}
function sanitizeSingleIdentifier(identifier: string): string {
const cleaned = identifier.replace(/`/g, '')
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
throw new Error(
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
)
}
return `\`${cleaned}\``
}

View File
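
The builders above keep all values parameterized and backtick-quote every identifier, so only the separately pattern-checked WHERE clause is interpolated into SQL. For example:

// buildInsertQuery yields a placeholder query plus a detached values array:
const { query, values } = buildInsertQuery('users', { name: 'Ada', age: 36 })
// query  => INSERT INTO `users` (`name`, `age`) VALUES (?, ?)
// values => ['Ada', 36]

// Identifiers that fail the [a-zA-Z_][a-zA-Z0-9_]* check throw instead of
// being quoted:
sanitizeIdentifier('users; DROP TABLE users') // throws Error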

@@ -0,0 +1,70 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createPostgresConnection, executeDelete } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLDeleteAPI')
const DeleteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = DeleteSchema.parse(body)
logger.info(
`[${requestId}] Deleting data from ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeDelete(sql, params.table, params.where)
logger.info(`[${requestId}] Delete executed successfully, ${result.rowCount} row(s) deleted`)
return NextResponse.json({
message: `Data deleted successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] PostgreSQL delete failed:`, error)
return NextResponse.json(
{ error: `PostgreSQL delete failed: ${errorMessage}` },
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,82 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import {
createPostgresConnection,
executeQuery,
validateQuery,
} from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLExecuteAPI')
const ExecuteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = ExecuteSchema.parse(body)
logger.info(
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
)
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
return NextResponse.json(
{ error: `Query validation failed: ${validation.error}` },
{ status: 400 }
)
}
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeQuery(sql, params.query)
logger.info(`[${requestId}] SQL executed successfully, ${result.rowCount} row(s) affected`)
return NextResponse.json({
message: `SQL executed successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] PostgreSQL execute failed:`, error)
return NextResponse.json(
{ error: `PostgreSQL execute failed: ${errorMessage}` },
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,92 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createPostgresConnection, executeInsert } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLInsertAPI')
const InsertSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
.record(z.unknown())
.refine((obj) => Object.keys(obj).length > 0, 'Data object cannot be empty'),
z
.string()
.min(1)
.transform((str) => {
try {
const parsed = JSON.parse(str)
if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
throw new Error('Data must be a JSON object')
}
return parsed
} catch (e) {
const errorMsg = e instanceof Error ? e.message : 'Unknown error'
throw new Error(
`Invalid JSON format in data field: ${errorMsg}. Received: ${str.substring(0, 100)}...`
)
}
}),
]),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = InsertSchema.parse(body)
logger.info(
`[${requestId}] Inserting data into ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeInsert(sql, params.table, params.data)
logger.info(`[${requestId}] Insert executed successfully, ${result.rowCount} row(s) inserted`)
return NextResponse.json({
message: `Data inserted successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] PostgreSQL insert failed:`, error)
return NextResponse.json(
{ error: `PostgreSQL insert failed: ${errorMessage}` },
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,66 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createPostgresConnection, executeQuery } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLQueryAPI')
const QuerySchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = QuerySchema.parse(body)
logger.info(
`[${requestId}] Executing PostgreSQL query on ${params.host}:${params.port}/${params.database}`
)
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeQuery(sql, params.query)
logger.info(`[${requestId}] Query executed successfully, returned ${result.rowCount} rows`)
return NextResponse.json({
message: `Query executed successfully. ${result.rowCount} row(s) returned.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] PostgreSQL query failed:`, error)
return NextResponse.json({ error: `PostgreSQL query failed: ${errorMessage}` }, { status: 500 })
}
}

View File

@@ -0,0 +1,89 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createPostgresConnection, executeUpdate } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLUpdateAPI')
const UpdateSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
.record(z.unknown())
.refine((obj) => Object.keys(obj).length > 0, 'Data object cannot be empty'),
z
.string()
.min(1)
.transform((str) => {
try {
const parsed = JSON.parse(str)
if (typeof parsed !== 'object' || parsed === null || Array.isArray(parsed)) {
throw new Error('Data must be a JSON object')
}
return parsed
} catch (e) {
throw new Error('Invalid JSON format in data field')
}
}),
]),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
const params = UpdateSchema.parse(body)
logger.info(
`[${requestId}] Updating data in ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
ssl: params.ssl,
})
try {
const result = await executeUpdate(sql, params.table, params.data, params.where)
logger.info(`[${requestId}] Update executed successfully, ${result.rowCount} row(s) updated`)
return NextResponse.json({
message: `Data updated successfully. ${result.rowCount} row(s) affected.`,
rows: result.rows,
rowCount: result.rowCount,
})
} finally {
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] PostgreSQL update failed:`, error)
return NextResponse.json(
{ error: `PostgreSQL update failed: ${errorMessage}` },
{ status: 500 }
)
}
}

View File

@@ -0,0 +1,194 @@
import postgres from 'postgres'
import type { PostgresConnectionConfig } from '@/tools/postgresql/types'
export function createPostgresConnection(config: PostgresConnectionConfig) {
const sslConfig =
config.ssl === 'disabled'
? false
: config.ssl === 'required'
? 'require'
: config.ssl === 'preferred'
? 'prefer'
: 'require'
const sql = postgres({
host: config.host,
port: config.port,
database: config.database,
username: config.username,
password: config.password,
ssl: sslConfig,
connect_timeout: 10, // 10 seconds
idle_timeout: 20, // 20 seconds
max_lifetime: 60 * 30, // 30 minutes
max: 1, // Single connection for tool usage
})
return sql
}
export async function executeQuery(
sql: any,
query: string,
params: unknown[] = []
): Promise<{ rows: unknown[]; rowCount: number }> {
const result = await sql.unsafe(query, params)
return {
rows: Array.isArray(result) ? result : [result],
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
}
}
export function validateQuery(query: string): { isValid: boolean; error?: string } {
const trimmedQuery = query.trim().toLowerCase()
// Block dangerous SQL operations
const dangerousPatterns = [
/drop\s+database/i,
/drop\s+schema/i,
/drop\s+user/i,
/create\s+user/i,
/create\s+role/i,
/grant\s+/i,
/revoke\s+/i,
/alter\s+user/i,
/alter\s+role/i,
/set\s+role/i,
/reset\s+role/i,
/copy\s+.*from/i,
/copy\s+.*to/i,
/lo_import/i,
/lo_export/i,
/pg_read_file/i,
/pg_write_file/i,
/pg_ls_dir/i,
/information_schema\.tables/i,
/pg_catalog/i,
/pg_user/i,
/pg_shadow/i,
/pg_roles/i,
/pg_authid/i,
/pg_stat_activity/i,
/dblink/i,
/\\\\copy/i,
]
for (const pattern of dangerousPatterns) {
if (pattern.test(query)) {
return {
isValid: false,
error: `Query contains potentially dangerous operation: ${pattern.source}`,
}
}
}
const allowedStatements = /^(select|insert|update|delete|with|explain|analyze|show)\s+/i
if (!allowedStatements.test(trimmedQuery)) {
return {
isValid: false,
error:
'Only SELECT, INSERT, UPDATE, DELETE, WITH, EXPLAIN, ANALYZE, and SHOW statements are allowed',
}
}
return { isValid: true }
}
export function sanitizeIdentifier(identifier: string): string {
if (identifier.includes('.')) {
const parts = identifier.split('.')
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
}
return sanitizeSingleIdentifier(identifier)
}
function validateWhereClause(where: string): void {
const dangerousPatterns = [
/;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
/union\s+select/i,
/into\s+outfile/i,
/load_file/i,
/--/,
/\/\*/,
/\*\//,
]
for (const pattern of dangerousPatterns) {
if (pattern.test(where)) {
throw new Error('WHERE clause contains potentially dangerous operation')
}
}
}
function sanitizeSingleIdentifier(identifier: string): string {
const cleaned = identifier.replace(/"/g, '')
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
throw new Error(
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
)
}
return `"${cleaned}"`
}
export async function executeInsert(
sql: any,
table: string,
data: Record<string, unknown>
): Promise<{ rows: unknown[]; rowCount: number }> {
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
const placeholders = columns.map((_, index) => `$${index + 1}`)
const values = columns.map((col) => data[col])
const query = `INSERT INTO ${sanitizedTable} (${sanitizedColumns.join(', ')}) VALUES (${placeholders.join(', ')}) RETURNING *`
const result = await sql.unsafe(query, values)
return {
rows: Array.isArray(result) ? result : [result],
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
}
}
export async function executeUpdate(
sql: any,
table: string,
data: Record<string, unknown>,
where: string
): Promise<{ rows: unknown[]; rowCount: number }> {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
const setClause = sanitizedColumns.map((col, index) => `${col} = $${index + 1}`).join(', ')
const values = columns.map((col) => data[col])
const query = `UPDATE ${sanitizedTable} SET ${setClause} WHERE ${where} RETURNING *`
const result = await sql.unsafe(query, values)
return {
rows: Array.isArray(result) ? result : [result],
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
}
}
export async function executeDelete(
sql: any,
table: string,
where: string
): Promise<{ rows: unknown[]; rowCount: number }> {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const query = `DELETE FROM ${sanitizedTable} WHERE ${where} RETURNING *`
const result = await sql.unsafe(query, [])
return {
rows: Array.isArray(result) ? result : [result],
rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
}
}

View File
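
Unlike the MySQL helpers' ? placeholders, these Postgres helpers number their parameters and append RETURNING *, so callers get the affected rows back directly. For example:

// executeInsert builds: INSERT INTO "users" ("name") VALUES ($1) RETURNING *
const inserted = await executeInsert(sql, 'users', { name: 'Ada' })
// inserted => { rows: [{ ... }], rowCount: 1 }

// executeUpdate builds: UPDATE "users" SET "name" = $1 WHERE id = 1 RETURNING *
const updated = await executeUpdate(sql, 'users', { name: 'Grace' }, 'id = 1')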

@@ -1,7 +1,7 @@
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getUserUsageLimitInfo, updateUserUsageLimit } from '@/lib/billing'
- import { updateMemberUsageLimit } from '@/lib/billing/core/organization-billing'
+ import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
import { createLogger } from '@/lib/logs/console/logger'
import { isOrganizationOwnerOrAdmin } from '@/lib/permissions/utils'
@@ -9,7 +9,7 @@ const logger = createLogger('UnifiedUsageLimitsAPI')
/**
* Unified Usage Limits Endpoint
- * GET/PUT /api/usage-limits?context=user|member&userId=<id>&organizationId=<id>
+ * GET/PUT /api/usage-limits?context=user|organization&userId=<id>&organizationId=<id>
*
*/
export async function GET(request: NextRequest) {
@@ -26,40 +26,13 @@ export async function GET(request: NextRequest) {
const organizationId = searchParams.get('organizationId')
// Validate context
- if (!['user', 'member'].includes(context)) {
+ if (!['user', 'organization'].includes(context)) {
return NextResponse.json(
- { error: 'Invalid context. Must be "user" or "member"' },
+ { error: 'Invalid context. Must be "user" or "organization"' },
{ status: 400 }
)
}
- // For member context, require organizationId and check permissions
- if (context === 'member') {
- if (!organizationId) {
- return NextResponse.json(
- { error: 'Organization ID is required when context=member' },
- { status: 400 }
- )
- }
- // Check if the current user has permission to view member usage info
- const hasPermission = await isOrganizationOwnerOrAdmin(session.user.id, organizationId)
- if (!hasPermission) {
- logger.warn('Unauthorized attempt to view member usage info', {
- requesterId: session.user.id,
- targetUserId: userId,
- organizationId,
- })
- return NextResponse.json(
- {
- error:
- 'Permission denied. Only organization owners and admins can view member usage information',
- },
- { status: 403 }
- )
- }
- }
// For user context, ensure they can only view their own info
if (context === 'user' && userId !== session.user.id) {
return NextResponse.json(
@@ -69,6 +42,23 @@ export async function GET(request: NextRequest) {
}
// Get usage limit info
+ if (context === 'organization') {
+ if (!organizationId) {
+ return NextResponse.json(
+ { error: 'Organization ID is required when context=organization' },
+ { status: 400 }
+ )
+ }
+ const org = await getOrganizationBillingData(organizationId)
+ return NextResponse.json({
+ success: true,
+ context,
+ userId,
+ organizationId,
+ data: org,
+ })
+ }
const usageLimitInfo = await getUserUsageLimitInfo(userId)
return NextResponse.json({
@@ -96,12 +86,11 @@ export async function PUT(request: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const { searchParams } = new URL(request.url)
- const context = searchParams.get('context') || 'user'
- const userId = searchParams.get('userId') || session.user.id
- const organizationId = searchParams.get('organizationId')
- const { limit } = await request.json()
+ const body = await request.json()
+ const limit = body?.limit
+ const context = body?.context || 'user'
+ const organizationId = body?.organizationId
+ const userId = session.user.id
if (typeof limit !== 'number' || limit < 0) {
return NextResponse.json(
@@ -110,52 +99,42 @@ export async function PUT(request: NextRequest) {
)
}
if (!['user', 'organization'].includes(context)) {
return NextResponse.json(
{ error: 'Invalid context. Must be "user" or "organization"' },
{ status: 400 }
)
}
if (context === 'user') {
// Update user's own usage limit
if (userId !== session.user.id) {
return NextResponse.json({ error: "Cannot update other users' limits" }, { status: 403 })
}
await updateUserUsageLimit(userId, limit)
- } else if (context === 'member') {
- // Update organization member's usage limit
+ } else if (context === 'organization') {
+ // context === 'organization'
if (!organizationId) {
return NextResponse.json(
- { error: 'Organization ID is required when context=member' },
+ { error: 'Organization ID is required when context=organization' },
{ status: 400 }
)
}
// Check if the current user has permission to update member limits
const hasPermission = await isOrganizationOwnerOrAdmin(session.user.id, organizationId)
if (!hasPermission) {
- logger.warn('Unauthorized attempt to update member usage limit', {
- adminUserId: session.user.id,
- targetUserId: userId,
- organizationId,
- })
- return NextResponse.json(
- {
- error:
- 'Permission denied. Only organization owners and admins can update member usage limits',
- },
- { status: 403 }
- )
+ return NextResponse.json({ error: 'Permission denied' }, { status: 403 })
}
- logger.info('Authorized member usage limit update', {
- adminUserId: session.user.id,
- targetUserId: userId,
- organizationId,
- newLimit: limit,
- })
- await updateMemberUsageLimit(organizationId, userId, limit, session.user.id)
- } else {
- return NextResponse.json(
- { error: 'Invalid context. Must be "user" or "member"' },
- { status: 400 }
+ // Use the dedicated function to update org usage limit
+ const { updateOrganizationUsageLimit } = await import(
+ '@/lib/billing/core/organization-billing'
)
const result = await updateOrganizationUsageLimit(organizationId, limit)
if (!result.success) {
return NextResponse.json({ error: result.error }, { status: 400 })
}
const updated = await getOrganizationBillingData(organizationId)
return NextResponse.json({ success: true, context, userId, organizationId, data: updated })
}
// Return updated limit info

View File
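
After this change, PUT reads limit, context, and organizationId from the JSON body rather than the query string, and the former 'member' context is replaced by 'organization'. A client call under the new contract (the endpoint path is the one documented in the handler's own comment; the ids are placeholders):

await fetch('/api/usage-limits', {
  method: 'PUT',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ context: 'organization', organizationId: 'org_123', limit: 500 }),
})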

@@ -0,0 +1,38 @@
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('UsageCheckAPI')
export async function GET(_request: NextRequest) {
const session = await getSession()
try {
const userId = session?.user?.id
if (!userId) return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
const result = await checkServerSideUsageLimits(userId)
// Normalize to client usage shape
return NextResponse.json({
success: true,
data: {
percentUsed:
result.limit > 0
? Math.min(Math.floor((result.currentUsage / result.limit) * 100), 100)
: 0,
isWarning:
result.limit > 0
? (result.currentUsage / result.limit) * 100 >= 80 &&
(result.currentUsage / result.limit) * 100 < 100
: false,
isExceeded: result.isExceeded,
currentUsage: result.currentUsage,
limit: result.limit,
message: result.message,
},
})
} catch (error) {
logger.error('Failed usage check', { error })
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}

View File
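
The normalization above clamps percentUsed to 100 and flags a warning only in the 80 to 99 percent band. A worked example of the same arithmetic:

// 450 used against a limit of 500:
const currentUsage = 450
const limit = 500
const percentUsed = Math.min(Math.floor((currentUsage / limit) * 100), 100) // 90
const isWarning =
  (currentUsage / limit) * 100 >= 80 && (currentUsage / limit) * 100 < 100 // true
// isExceeded comes from the server-side limits check; here it would be false.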

@@ -1,11 +1,10 @@
import { unstable_noStore as noStore } from 'next/cache'
import { type NextRequest, NextResponse } from 'next/server'
import OpenAI, { AzureOpenAI } from 'openai'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
export const dynamic = 'force-dynamic'
- export const runtime = 'edge'
+ export const runtime = 'nodejs'
export const maxDuration = 60
const logger = createLogger('WandGenerateAPI')
@@ -50,6 +49,15 @@ interface RequestBody {
history?: ChatMessage[]
}
// Helper: safe stringify for error payloads that may include circular structures
function safeStringify(value: unknown): string {
try {
return JSON.stringify(value)
} catch {
return '[unserializable]'
}
}
export async function POST(req: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
logger.info(`[${requestId}] Received wand generation request`)
@@ -63,7 +71,6 @@ export async function POST(req: NextRequest) {
}
try {
noStore()
const body = (await req.json()) as RequestBody
const { prompt, systemPrompt, stream = false, history = [] } = body
@@ -108,60 +115,176 @@ export async function POST(req: NextRequest) {
`[${requestId}] Starting streaming request to ${useWandAzure ? 'Azure OpenAI' : 'OpenAI'}`
)
- const streamCompletion = await client.chat.completions.create({
- model: useWandAzure ? wandModelName : 'gpt-4o',
- messages: messages,
- temperature: 0.3,
- max_tokens: 10000,
- stream: true,
+ logger.info(
+ `[${requestId}] About to create stream with model: ${useWandAzure ? wandModelName : 'gpt-4o'}`
+ )
// Use native fetch for streaming to avoid OpenAI SDK issues with Node.js runtime
const apiUrl = useWandAzure
? `${azureEndpoint}/openai/deployments/${wandModelName}/chat/completions?api-version=${azureApiVersion}`
: 'https://api.openai.com/v1/chat/completions'
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
if (useWandAzure) {
headers['api-key'] = azureApiKey!
} else {
headers.Authorization = `Bearer ${openaiApiKey}`
}
logger.debug(`[${requestId}] Making streaming request to: ${apiUrl}`)
const response = await fetch(apiUrl, {
method: 'POST',
headers,
body: JSON.stringify({
model: useWandAzure ? wandModelName : 'gpt-4o',
messages: messages,
temperature: 0.3,
max_tokens: 10000,
stream: true,
stream_options: { include_usage: true },
}),
})
logger.debug(`[${requestId}] Stream connection established successfully`)
if (!response.ok) {
const errorText = await response.text()
logger.error(`[${requestId}] API request failed`, {
status: response.status,
statusText: response.statusText,
error: errorText,
})
throw new Error(`API request failed: ${response.status} ${response.statusText}`)
}
return new Response(
new ReadableStream({
async start(controller) {
const encoder = new TextEncoder()
logger.info(`[${requestId}] Stream response received, starting processing`)
try {
for await (const chunk of streamCompletion) {
const content = chunk.choices[0]?.delta?.content || ''
if (content) {
// Use SSE format identical to chat streaming
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
)
}
// Create a TransformStream to process the SSE data
const encoder = new TextEncoder()
const decoder = new TextDecoder()
const readable = new ReadableStream({
async start(controller) {
const reader = response.body?.getReader()
if (!reader) {
controller.close()
return
}
try {
let buffer = ''
let chunkCount = 0
while (true) {
const { done, value } = await reader.read()
if (done) {
logger.info(`[${requestId}] Stream completed. Total chunks: ${chunkCount}`)
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
controller.close()
break
}
// Send completion signal in SSE format
controller.enqueue(encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`))
controller.close()
logger.info(`[${requestId}] Wand generation streaming completed`)
} catch (streamError: any) {
logger.error(`[${requestId}] Streaming error`, { error: streamError.message })
controller.enqueue(
encoder.encode(
`data: ${JSON.stringify({ error: 'Streaming failed', done: true })}\n\n`
)
)
controller.close()
// Decode the chunk
buffer += decoder.decode(value, { stream: true })
// Process complete SSE messages
const lines = buffer.split('\n')
buffer = lines.pop() || '' // Keep incomplete line in buffer
for (const line of lines) {
if (line.startsWith('data: ')) {
const data = line.slice(6).trim()
if (data === '[DONE]') {
logger.info(`[${requestId}] Received [DONE] signal`)
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ done: true })}\n\n`)
)
controller.close()
return
}
try {
const parsed = JSON.parse(data)
const content = parsed.choices?.[0]?.delta?.content
if (content) {
chunkCount++
if (chunkCount === 1) {
logger.info(`[${requestId}] Received first content chunk`)
}
// Forward the content
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ chunk: content })}\n\n`)
)
}
// Log usage if present
if (parsed.usage) {
logger.info(
`[${requestId}] Received usage data: ${JSON.stringify(parsed.usage)}`
)
}
// Log progress periodically
if (chunkCount % 10 === 0) {
logger.debug(`[${requestId}] Processed ${chunkCount} chunks`)
}
} catch (parseError) {
// Skip invalid JSON lines
logger.debug(
`[${requestId}] Skipped non-JSON line: ${data.substring(0, 100)}`
)
}
}
}
}
},
}),
{
headers: {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
'X-Accel-Buffering': 'no',
},
}
)
logger.info(`[${requestId}] Wand generation streaming completed successfully`)
} catch (streamError: any) {
logger.error(`[${requestId}] Streaming error`, {
name: streamError?.name,
message: streamError?.message || 'Unknown error',
stack: streamError?.stack,
})
// Send error to client
const errorData = `data: ${JSON.stringify({ error: 'Streaming failed', done: true })}\n\n`
controller.enqueue(encoder.encode(errorData))
controller.close()
} finally {
reader.releaseLock()
}
},
})
// Return Response with proper headers for Node.js runtime
return new Response(readable, {
headers: {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache, no-transform',
Connection: 'keep-alive',
'X-Accel-Buffering': 'no', // Disable Nginx buffering
'Transfer-Encoding': 'chunked', // Important for Node.js runtime
},
})
} catch (error: any) {
- logger.error(`[${requestId}] Streaming error`, {
- error: error.message || 'Unknown error',
- stack: error.stack,
+ logger.error(`[${requestId}] Failed to create stream`, {
+ name: error?.name,
+ message: error?.message || 'Unknown error',
+ code: error?.code,
+ status: error?.status,
+ responseStatus: error?.response?.status,
+ responseData: error?.response?.data ? safeStringify(error.response.data) : undefined,
+ stack: error?.stack,
+ useWandAzure,
+ model: useWandAzure ? wandModelName : 'gpt-4o',
+ endpoint: useWandAzure ? azureEndpoint : 'api.openai.com',
+ apiVersion: useWandAzure ? azureApiVersion : 'N/A',
})
return NextResponse.json(
@@ -195,8 +318,19 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ success: true, content: generatedContent })
} catch (error: any) {
logger.error(`[${requestId}] Wand generation failed`, {
- error: error.message || 'Unknown error',
- stack: error.stack,
+ name: error?.name,
+ message: error?.message || 'Unknown error',
+ code: error?.code,
+ status: error?.status,
+ responseStatus: error instanceof OpenAI.APIError ? error.status : error?.response?.status,
+ responseData: (error as any)?.response?.data
+ ? safeStringify((error as any).response.data)
+ : undefined,
+ stack: error?.stack,
+ useWandAzure,
+ model: useWandAzure ? wandModelName : 'gpt-4o',
+ endpoint: useWandAzure ? azureEndpoint : 'api.openai.com',
+ apiVersion: useWandAzure ? azureApiVersion : 'N/A',
})
let clientErrorMessage = 'Wand generation failed. Please try again later.'

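For context, the route above emits a custom SSE shape: `{ chunk }` for content deltas, `{ done: true }` as the terminator, and `{ error, done: true }` on failure. A client consuming it needs the same buffer-and-split handling as the proxy itself. A minimal sketch, assuming a `/api/wand` endpoint and a `{ prompt }` payload (both illustrative, not confirmed by this diff):

// Sketch of a client reading the wand SSE stream; endpoint path and payload are assumptions.
async function readWandStream(prompt: string, onChunk: (text: string) => void) {
  const res = await fetch('/api/wand', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({ prompt }),
  })
  if (!res.body) throw new Error('No response body')
  const reader = res.body.getReader()
  const decoder = new TextDecoder()
  let buffer = ''
  while (true) {
    const { done, value } = await reader.read()
    if (done) break
    buffer += decoder.decode(value, { stream: true })
    const lines = buffer.split('\n')
    buffer = lines.pop() || '' // keep the incomplete line, mirroring the server loop
    for (const line of lines) {
      if (!line.startsWith('data: ')) continue
      const payload = JSON.parse(line.slice(6))
      if (payload.error) throw new Error(payload.error)
      if (payload.chunk) onChunk(payload.chunk)
      if (payload.done) return
    }
  }
}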
View File

@@ -495,7 +495,9 @@ async function createAirtableWebhookSubscription(
} else {
logger.info(
`[${requestId}] Successfully created webhook in Airtable for webhook ${webhookData.id}.`,
{
airtableWebhookId: responseBody.id,
}
)
// Store the airtableWebhookId (responseBody.id) within the providerConfig
try {

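The hunk is cut off right after the comment about persisting `responseBody.id`. For orientation, a minimal sketch of what that persistence step could look like, assuming a drizzle `webhook` table with a jsonb `providerConfig` column — the table, column, and key names here are illustrative, not taken from this diff:

// Illustrative only: `webhook` table, `providerConfig` column, and `externalId` key are assumptions.
import { eq } from 'drizzle-orm'

await db
  .update(webhook)
  .set({
    providerConfig: {
      ...(webhookData.providerConfig ?? {}),
      externalId: responseBody.id, // the Airtable webhook id returned above
    },
  })
  .where(eq(webhook.id, webhookData.id))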
View File

@@ -11,7 +11,6 @@ export async function GET(req: NextRequest) {
const token = req.nextUrl.searchParams.get('token')
if (!token) {
// Redirect to a page explaining the error
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=missing-token',
@@ -68,40 +67,39 @@ export async function GET(req: NextRequest) {
const userEmail = session.user.email.toLowerCase()
const invitationEmail = invitation.email.toLowerCase()
// Get user data to check email verification status and for error messages
const userData = await db
.select()
.from(user)
.where(eq(user.id, session.user.id))
.then((rows) => rows[0])
if (!userData) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=user-not-found',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Check if user's email is verified
if (!userData.emailVerified) {
return NextResponse.redirect(
new URL(
`/invite/invite-error?reason=email-not-verified&details=${encodeURIComponent(`You must verify your email address (${userData.email}) before accepting invitations.`)}`,
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Check if the logged-in user's email matches the invitation
const isValidMatch = userEmail === invitationEmail
if (!isValidMatch) {
return NextResponse.redirect(
new URL(
`/invite/invite-error?reason=email-mismatch&details=${encodeURIComponent(`Invitation was sent to ${invitation.email}, but you're logged in as ${userData.email}`)}`,
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)

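The rewrite above drops the old fuzzy matching (same domain plus normalized-username similarity, e.g. john.doe@company.com vs john@company.com) in favor of a strict, ordered sequence of checks. Condensed into a sketch — `redirectToError` is a hypothetical helper standing in for the inline `NextResponse.redirect` calls:

// Hypothetical helper standing in for the inline redirects in the route above.
const redirectToError = (reason: string) =>
  NextResponse.redirect(
    new URL(`/invite/invite-error?reason=${reason}`, env.NEXT_PUBLIC_APP_URL || 'https://sim.ai')
  )

// The validation order the new code establishes:
if (!userData) return redirectToError('user-not-found') // 1. the user row must exist
if (!userData.emailVerified) return redirectToError('email-not-verified') // 2. the email must be verified
if (userEmail !== invitationEmail) return redirectToError('email-mismatch') // 3. exact match only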
View File

@@ -122,8 +122,8 @@
--popover-foreground: 0 0% 98%;
/* Primary Colors */
--primary: 0 0% 11.2%;
--primary-foreground: 0 0% 98%;
/* Secondary Colors */
--secondary: 0 0% 12.0%;

View File

@@ -1,12 +1,13 @@
'use client'
import { useEffect, useState } from 'react'
import { AlertCircle, CheckCircle2, Mail, UserPlus, Users2 } from 'lucide-react'
import Image from 'next/image'
import { useParams, useRouter, useSearchParams } from 'next/navigation'
import { Button } from '@/components/ui/button'
import { LoadingAgent } from '@/components/ui/loading-agent'
import { client, useSession } from '@/lib/auth-client'
import { useBrandConfig } from '@/lib/branding/branding'
export default function Invite() {
const router = useRouter()
@@ -14,6 +15,7 @@ export default function Invite() {
const inviteId = params.id as string
const searchParams = useSearchParams()
const { data: session, isPending } = useSession()
const brandConfig = useBrandConfig()
const [invitationDetails, setInvitationDetails] = useState<any>(null)
const [isLoading, setIsLoading] = useState(true)
const [error, setError] = useState<string | null>(null)
@@ -174,28 +176,46 @@ export default function Invite() {
const callbackUrl = encodeURIComponent(getCallbackUrl())
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
<div className='mb-8'>
<Image
src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<div className='flex w-full max-w-md flex-col items-center text-center'>
<div className='mb-6 rounded-full bg-blue-50 p-3 dark:bg-blue-950/20'>
<UserPlus className='h-8 w-8 text-blue-500 dark:text-blue-400' />
</div>
<h1 className='mb-2 font-semibold text-black text-xl dark:text-white'>
You've been invited!
</h1>
<p className='mb-6 text-gray-600 text-sm leading-relaxed dark:text-gray-300'>
{isNewUser
? 'Create an account to join this workspace on Sim'
: 'Sign in to your account to accept this invitation'}
</p>
<div className='flex w-full flex-col gap-3'>
{isNewUser ? (
<>
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
onClick={() => router.push(`/signup?callbackUrl=${callbackUrl}&invite_flow=true`)}
>
Create an account
</Button>
<Button
variant='outline'
className='w-full border-brand-primary text-brand-primary hover:bg-brand-primary hover:text-white'
onClick={() => router.push(`/login?callbackUrl=${callbackUrl}&invite_flow=true`)}
>
I already have an account
@@ -205,13 +225,14 @@ export default function Invite() {
<>
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
onClick={() => router.push(`/login?callbackUrl=${callbackUrl}&invite_flow=true`)}
>
Sign in
</Button>
<Button
variant='outline'
className='w-full border-brand-primary text-brand-primary hover:bg-brand-primary hover:text-white'
onClick={() =>
router.push(`/signup?callbackUrl=${callbackUrl}&invite_flow=true&new=true`)
}
@@ -220,8 +241,23 @@ export default function Invite() {
</Button>
</>
)}
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
onClick={() => router.push('/')}
>
Return to Home
</Button>
</div>
</div>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}
@@ -229,9 +265,26 @@ export default function Invite() {
// Show loading state
if (isLoading || isPending) {
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
<div className='mb-8'>
<Image
src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<LoadingAgent size='lg' />
<p className='mt-4 text-gray-400 text-sm'>Loading invitation...</p>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}
@@ -239,14 +292,41 @@ export default function Invite() {
// Show error state
if (error) {
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
<div className='mb-8'>
<Image
src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<div className='flex w-full max-w-md flex-col items-center text-center'>
<div className='mb-6 rounded-full bg-red-50 p-3 dark:bg-red-950/20'>
<AlertCircle className='h-8 w-8 text-red-500 dark:text-red-400' />
</div>
<h1 className='mb-2 font-semibold text-black text-xl dark:text-white'>
Invitation Error
</h1>
<p className='mb-6 text-gray-600 text-sm leading-relaxed dark:text-gray-300'>{error}</p>
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
onClick={() => router.push('/')}
>
Return to Home
</Button>
</div>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}
@@ -254,41 +334,113 @@ export default function Invite() {
// Show success state
if (accepted) {
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
<div className='mb-8'>
<Image
src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<div className='flex w-full max-w-md flex-col items-center text-center'>
<div className='mb-6 rounded-full bg-green-50 p-3 dark:bg-green-950/20'>
<CheckCircle2 className='h-8 w-8 text-green-500 dark:text-green-400' />
</div>
<h1 className='mb-2 font-semibold text-black text-xl dark:text-white'>Welcome!</h1>
<p className='mb-6 text-gray-600 text-sm leading-relaxed dark:text-gray-300'>
You have successfully joined {invitationDetails?.name || 'the workspace'}. Redirecting
to your workspace...
</p>
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
onClick={() => router.push('/')}
>
Return to Home
</Button>
</div>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}
// Show invitation details
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
<div className='mb-8'>
<Image
src='/logo/b&w/medium.png'
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<div className='flex w-full max-w-md flex-col items-center text-center'>
<div className='mb-6 rounded-full bg-blue-50 p-3 dark:bg-blue-950/20'>
{invitationType === 'organization' ? (
<Users2 className='h-8 w-8 text-blue-500 dark:text-blue-400' />
) : (
<Mail className='h-8 w-8 text-blue-500 dark:text-blue-400' />
)}
</div>
<h1 className='mb-2 font-semibold text-black text-xl dark:text-white'>
{invitationType === 'organization' ? 'Organization Invitation' : 'Workspace Invitation'}
</h1>
<p className='mb-6 text-gray-600 text-sm leading-relaxed dark:text-gray-300'>
You've been invited to join{' '}
<span className='font-medium text-black dark:text-white'>
{invitationDetails?.name || `a ${invitationType}`}
</span>
. Click accept below to join.
</p>
<div className='flex w-full flex-col gap-3'>
<Button
onClick={handleAcceptInvitation}
disabled={isAccepting}
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
>
{isAccepting ? (
<>
<LoadingAgent size='sm' />
Accepting...
</>
) : (
'Accept Invitation'
)}
</Button>
<Button
variant='ghost'
className='w-full text-gray-600 hover:bg-gray-200 hover:text-black dark:text-gray-400 dark:hover:bg-gray-800 dark:hover:text-white'
onClick={() => router.push('/')}
>
Return to Home
</Button>
</div>
</div>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}

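Every state of this page (account chooser, loading, error, accepted, invitation details) repeats the same logo header and support footer. If that duplication becomes a maintenance burden, it could be factored into a small wrapper — a hypothetical refactor sketch reusing the file's existing `Image` and `useBrandConfig` imports, not something this diff does:

// Hypothetical shared shell; not part of this diff.
function InviteShell({ children }: { children: React.ReactNode }) {
  const brandConfig = useBrandConfig()
  return (
    <div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
      <div className='mb-8'>
        <Image
          src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
          alt='Sim Logo'
          width={120}
          height={67}
          className='dark:invert'
          priority
        />
      </div>
      {children}
      <footer className='mt-8 text-center text-gray-500 text-xs'>
        Need help?{' '}
        <a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
          Contact support
        </a>
      </footer>
    </div>
  )
}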
View File

@@ -1,10 +1,12 @@
'use client'
import { useEffect, useState } from 'react'
import { Mail, RotateCcw, ShieldX } from 'lucide-react'
import Image from 'next/image'
import Link from 'next/link'
import { useSearchParams } from 'next/navigation'
import { Button } from '@/components/ui/button'
import { useBrandConfig } from '@/lib/branding/branding'
function getErrorMessage(reason: string, details?: string): string {
switch (reason) {
@@ -22,6 +24,18 @@ function getErrorMessage(reason: string, details?: string): string {
: 'This invitation was sent to a different email address than the one you are logged in with.'
case 'workspace-not-found':
return 'The workspace associated with this invitation could not be found.'
case 'user-not-found':
return 'Your user account could not be found. Please try logging out and logging back in.'
case 'email-not-verified':
return details
? details
: 'You must verify your email address before accepting invitations. Please check your email for a verification link.'
case 'already-member':
return 'You are already a member of this organization or workspace.'
case 'invalid-invitation':
return 'This invitation is invalid or no longer exists.'
case 'missing-invitation-id':
return 'The invitation link is missing required information. Please use the original invitation link.'
case 'server-error':
return 'An unexpected error occurred while processing your invitation. Please try again later.'
default:
@@ -34,6 +48,7 @@ export default function InviteError() {
const reason = searchParams?.get('reason') || 'unknown'
const details = searchParams?.get('details')
const [errorMessage, setErrorMessage] = useState('')
const brandConfig = useBrandConfig()
useEffect(() => {
// Only set the error message on the client side
@@ -43,31 +58,79 @@ export default function InviteError() {
// Provide a fallback message for SSR
const displayMessage = errorMessage || 'Loading error details...'
const isEmailVerificationError = reason === 'email-not-verified'
const isExpiredError = reason === 'expired'
return (
<div className='flex min-h-screen flex-col items-center justify-center bg-white px-4 dark:bg-black'>
{/* Logo */}
<div className='mb-8'>
<Image
src={brandConfig.logoUrl || '/logo/b&w/medium.png'}
alt='Sim Logo'
width={120}
height={67}
className='dark:invert'
priority
/>
</div>
<div className='flex w-full max-w-md flex-col items-center text-center'>
<div className='mb-6 rounded-full bg-red-50 p-3 dark:bg-red-950/20'>
<ShieldX className='h-8 w-8 text-red-500 dark:text-red-400' />
</div>
<h1 className='mb-2 font-semibold text-black text-xl dark:text-white'>Invitation Error</h1>
<p className='mb-6 text-gray-600 text-sm leading-relaxed dark:text-gray-300'>
{displayMessage}
</p>
<div className='flex w-full flex-col gap-3'>
{isEmailVerificationError && (
<Button
variant='default'
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
asChild
>
<Link href='/verify'>
<Mail className='mr-2 h-4 w-4' />
Verify Email
</Link>
</Button>
)}
{isExpiredError && (
<Button
variant='outline'
className='w-full border-brand-primary text-brand-primary hover:bg-brand-primary hover:text-white'
asChild
>
<Link href='/'>
<RotateCcw className='mr-2 h-4 w-4' />
Request New Invitation
</Link>
</Button>
)}
<Button
className='w-full'
style={{ backgroundColor: 'var(--brand-primary-hex)', color: 'white' }}
asChild
>
<Link href='/'>Return to Home</Link>
</Button>
</div>
</div>
<footer className='mt-8 text-center text-gray-500 text-xs'>
Need help?{' '}
<a href='mailto:help@sim.ai' className='text-blue-400 hover:text-blue-300'>
Contact support
</a>
</footer>
</div>
)
}

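For completeness: the `reason` and `details` query parameters consumed by `getErrorMessage` are produced by the invite route's redirects earlier in this diff. A condensed sketch of that producer side — the helper name is illustrative; the route builds these URLs inline:

// Illustrative helper; the invite route above constructs these URLs inline.
function inviteErrorRedirect(reason: string, details?: string) {
  const url = new URL('/invite/invite-error', env.NEXT_PUBLIC_APP_URL || 'https://sim.ai')
  url.searchParams.set('reason', reason)
  if (details) url.searchParams.set('details', details) // searchParams handles the URL encoding
  return NextResponse.redirect(url)
}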
Some files were not shown because too many files have changed in this diff.