Compare commits

...

123 Commits

Author SHA1 Message Date
Waleed
60a061e38a v0.3.47: race condition fixes, store rehydration consolidation, other bugs 2025-09-04 22:36:42 -07:00
Waleed
ab71fcfc49 feat(invitations): add ability to resend invitations with cooldown, fixed dark mode UI issues (#1256) 2025-09-04 22:15:27 -07:00
Vikhyath Mondreti
864622c1dc fix(ratelimits): enterprise and team checks should be pooled limit (#1255)
* fix(ratelimits): enterprise and team checks should be pooled limit

* fix

* fix dynamic imports

* fix tests"
;
2025-09-04 21:44:56 -07:00
Waleed
8668622d66 feat(duplicate): duplicate variables when duplicating a workflow (#1254)
* feat(duplicate): duplicate variables when duplicating a workflow

* better typing
2025-09-04 21:20:30 -07:00
Waleed
53dd277cfe fix(cost): restored cost reporting for agent block in console entry (#1253) 2025-09-04 21:12:15 -07:00
Vikhyath Mondreti
0e8e8c7a47 fix(sidebar): order by created at (#1251) 2025-09-04 20:23:00 -07:00
Vikhyath Mondreti
47da5eb6e8 fix(rehydration): consolidate store rehydration code (#1249)
* fix(rehydration): consolidate store rehydration code

* fix stale closure
2025-09-04 20:00:51 -07:00
Vikhyath Mondreti
37dcde2afc feat(enterprise-plan-webhooks): skip webhook queue for enterprise plan users (#1250)
* feat(enterprise-plan-webhooks): skip webhook queue for enterprise plan users

* reuse subscription record instead of making extra db call
2025-09-04 20:00:24 -07:00
Vikhyath Mondreti
e31627c7c2 fix(sidebar): re-ordering based on last edit is confusing (#1248) 2025-09-04 18:30:59 -07:00
Vikhyath Mondreti
57c98d86ba fix(race-condition-workflow-switching): another race condition between registry and workflow stores (#1247)
* fix(race-condition-workflow-switching): another race condition between registry and workflow stores

* fix initial load race cond + cleanup

* fix initial load issue + simplify
2025-09-04 18:02:00 -07:00
Vikhyath Mondreti
0f7dfe084a fix(hydration): duplicate overlay after idle + subblocks race condition (#1246)
* fix(hydration): duplicate overlay after idle + subblocks race condition

* remove random timeout

* re-use correct helper

* remove redundant check

* add check

* remove third init func
2025-09-04 16:18:35 -07:00
Siddharth Ganesan
afc1632830 Merge pull request #1245 from simstudioai/fix/copilot-billing
improvement(copilot): billing multiplier adjustments
2025-09-04 12:05:17 -07:00
Siddharth Ganesan
56eee2c2d2 Warning 2025-09-04 11:37:06 -07:00
Siddharth Ganesan
fc558a8eef Lint + tests 2025-09-04 11:35:03 -07:00
Siddharth Ganesan
c68cadfb84 Docs 2025-09-04 11:27:54 -07:00
Siddharth Ganesan
95d93a2532 change 2025-09-04 11:23:36 -07:00
Siddharth Ganesan
59b2023124 Lint 2025-09-04 11:19:41 -07:00
Siddharth Ganesan
a672f17136 Add input/output multipliers 2025-09-04 11:19:00 -07:00
Waleed
1de59668e4 fix(whitelabel): move redirects (build-time) for whitelabeling to middleware (runtime) (#1236) 2025-09-03 16:36:47 -07:00
Waleed
26243b99e8 fix(code-subblock): added validation to not parse non-variables as variables in the code subblock (#1240)
* fix(code-subblock): added validation to not parse non-variables as variables in the code subblock

* fix wand prompt bar styling

* fix error message for available connected blocks to only show connected available blocks, not block IDs

* ui
2025-09-03 16:09:02 -07:00
Siddharth Ganesan
fce1423d05 v0.3.46: fix copilot stats updates
2025-09-03 13:26:00 -07:00
Siddharth Ganesan
3656d3d7ad Updates (#1237) 2025-09-03 13:19:34 -07:00
Waleed
581929bc01 v0.3.45: fixes for organization invites, custom tool execution 2025-09-03 08:31:56 -07:00
Waleed
11d8188415 fix(rce): always use VM over RCE for custom tools (#1233) 2025-09-03 08:16:50 -07:00
Waleed
36c98d18e9 fix(team): fix organization invitation URL for teams (#1232) 2025-09-03 08:05:38 -07:00
Waleed
0cf87e650d v0.3.44: removing unused routes, whitelabeling terms & policy URLs, e2b remote code execution, copilot improvements 2025-09-02 21:29:55 -07:00
Waleed
baef8d77f9 fix(styling): fix styling inconsistencies in dark mode, fix invites fetching to show active members (#1229)
* fix(styling): fix unreadable text in dark mode

* fix styling inconsistencies in kb

* refetch permissions on invite modal open

---------

Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
2025-09-02 21:17:15 -07:00
Vikhyath Mondreti
b74ab46820 fix(e2b-env-var): use isTruthy and getEnv (#1228) 2025-09-02 20:03:43 -07:00
Waleed
533b4c53e0 feat(tools): add MongoDB (#1225)
* added mongo, haven't tested

* fixed bugs, refined prompts, added billing for wand if billing enabled

* add docs

* ack PR comments
2025-09-02 18:55:45 -07:00
Siddharth Ganesan
c2d668c3eb feat(copilot): stats tracking (#1227)
* Add copilot stats table schema

* Move db to agent

* Lint

* Fix tests
2025-09-02 18:17:50 -07:00
Vikhyath Mondreti
1a5d5ddffa feat(e2b-execution): add remote code execution to support Python + Imports (#1226)
* feat(e2b-execution): add remote code execution via e2b

* ux improvements

* fix streaming

* progress

* fix tooltip text

* make supported languages an enum

* fix error handling

* fix tests
2025-09-02 18:15:29 -07:00
Waleed
9de0d91f9a feat(llms): added additional params to llm-based blocks for alternative models (#1223)
* feat(llms): added additional params to llm-based blocks for alternative models

* add hidden temp param to other LLM-based blocks
2025-09-02 13:29:03 -07:00
Waleed
3db73ff721 fix(whitelabel): make terms and privacy URL envvars available at build time (#1222) 2025-09-02 12:54:30 -07:00
Vikhyath Mondreti
9ffb48ee02 make 79th migration idempotent 2025-09-02 10:48:22 -07:00
Vikhyath Mondreti
1f2a317ac2 fix if not exists check 2025-09-02 10:39:53 -07:00
Vikhyath Mondreti
a618d289d8 add if not exists check 2025-09-02 10:38:33 -07:00
Vikhyath Mondreti
461d7b2342 Merge branch 'staging' of github.com:simstudioai/sim into staging 2025-09-02 10:27:08 -07:00
Vikhyath Mondreti
4273161c0f fix 80th migration 2025-09-02 10:26:57 -07:00
Waleed
54d42b33eb fix(wand): remove duplicate transfer encoding header meant to be set by nginx proxy (#1221) 2025-09-02 09:15:25 -07:00
Waleed
2c2c32c64b improvement(hygiene): refactored routes to be more restful, reduced code surface area and removed redundant code (#1217)
* improvement(invitations): consolidate invite-error and invite pages, made API endpoints more restful and reduced code surface area for invitations by 50%

* refactored logs API routes

* refactor rate limit api route, consolidate usage check api endpoint

* refactored chat page and invitations page

* consolidate ollama and openrouter stores to just providers store

* removed unused route

* removed legacy envvar methods

* remove dead, legacy routes for invitations PUT and workflow SYNC

* improvement(copilot): improve context inputs and fix some bugs (#1216)

* Add logs v1

* Update

* Updates

* Updates

* Fixes

* Fix current workflow in context

* Fix mentions

* Error handling

* Fix chat loading

* Hide current workflow from context

* Run workflow fix

* Lint

* updated invitation log

* styling for invitation pages

---------

Co-authored-by: Siddharth Ganesan <33737564+Sg312@users.noreply.github.com>
2025-09-01 21:22:23 -07:00
Waleed
65e861822c fix(ui): dark mode styling for switch, trigger modal UI, signup/login improvements with auto-submit for OTP (#1214)
* fix(ui): fix dark mode styling for switch, fix trigger modal UI

* auto-submit OTP when characters are entered

* trim leading and trailing whitespace from name on signup, throw more informative error messages on reset pass
2025-09-01 21:19:12 -07:00
Siddharth Ganesan
12135d2aa8 improvement(copilot): improve context inputs and fix some bugs (#1216)
* Add logs v1

* Update

* Updates

* Updates

* Fixes

* Fix current workflow in context

* Fix mentions

* Error handling

* Fix chat loading

* Hide current workflow from context

* Run workflow fix

* Lint
2025-09-01 16:51:58 -07:00
Waleed
f75c807580 improvement(performance): added new indexes for improved session performance (#1215) 2025-09-01 16:00:15 -07:00
Vikhyath Mondreti
9ea7ea79e9 feat(workspace-vars): add workspace scoped environment + fix cancellation of assoc. workspace invites if org invite cancelled (#1208)
* feat(env-vars): workspace scoped environment variables

* fix cascade delete of workspace invite if org invite with attached workspace invites is created

* remove redundant refetch

* feat(env-vars): workspace scoped environment variables

* fix redirect for invitation error, remove check for validated emails on workspace invitation accept

* styling improvements

* remove random migration code

* stronger typing, added helpers, parallelized envvar encryption

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-09-01 15:56:58 -07:00
Waleed
5bbb349d8a fix(build): add missing pdf-parse dep, add docker build in staging (#1213)
* fix(build): add missing pdf-parse dep

* add docker build (no push) in staging
2025-09-01 13:04:16 -07:00
Waleed
ea09fcecb7 fix(build): consolidate pdf parsing dependencies, remove extraneous html deps (#1212)
* fix(build): consolidate pdf parsing dependencies, remove extraneous html deps

* add types
2025-09-01 10:19:24 -07:00
Waleed
9ccb7600f9 fix(organizations): remove org calls when billing is disabled (#1211) 2025-09-01 09:48:58 -07:00
Waleed
ee17cf461a v0.3.43: added additional parsers, mysql block improvements, billing fixes, permission fixes 2025-08-31 01:01:24 -07:00
Waleed
43cb124d97 fix(parsers): fix md, pptx, html kb uploads (#1209)
* fix md, pptx, html

* consolidate consts
2025-08-31 00:52:42 -07:00
Waleed
76889fde26 fix(permissions): remove permissions granted by org membership (#1206)
* fix(permissions): remove cross-functional permissions granted by org membership

* code hygiene
2025-08-30 18:14:01 -07:00
Vikhyath Mondreti
7780d9b32b fix(enterprise-billing): simplification to be fixed-cost (#1196)
* fix(enterprise-billing): simplify

* conceptual improvement

* add seats to enterprise sub meta

* correct type

* fix UI

* send emails to new enterprise users

* fix fallback

* fix merge conflict issue

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-08-30 17:26:17 -07:00
Waleed
4a703a02cb improvement(tools): update mysql to respect ssl pref (#1205) 2025-08-30 13:48:39 -07:00
Waleed
a969d09782 feat(parsers): added pptx, md, & html parsers (#1202)
* feat(parsers): added pptx, md, & html parsers

* ack PR comments

* file renaming, reorganization
2025-08-30 02:11:01 -07:00
Waleed
0bc778130f v0.3.42: kb config defaults, downgrade nextjs 2025-08-29 21:51:00 -07:00
Waleed
df3d532495 fix(deps): downgrade nextjs (#1200) 2025-08-29 21:44:51 -07:00
Waleed
f4f8fc051e improvement(kb): add fallbacks for kb configs (#1199) 2025-08-29 21:09:09 -07:00
Waleed
76fac13f3d v0.3.41: wand with azure openai, generic mysql and postgres blocks 2025-08-29 19:19:29 -07:00
Waleed
a3838302e0 feat(kb): add adjustable concurrency and batching to uploads and embeddings (#1198) 2025-08-29 18:37:23 -07:00
Waleed
4310dd6c15 improvement(pg): added wand config for writing sql queries for generic db blocks & supabase postgrest syntax (#1197)
* add parallel ai, postgres, mysql, slight modifications to dark mode styling

* bun install frozen lockfile

* new deps

* improve security, add wand to short input and update wand config
2025-08-29 18:32:07 -07:00
Waleed
813a0fb741 feat(tools): add parallel ai, postgres, mysql, slight modifications to dark mode styling (#1192)
* add parallel ai, postgres, mysql, slight modifications to dark mode styling

* bun install frozen lockfile

* new deps
2025-08-29 17:25:02 -07:00
Waleed
7e23e942d7 fix(billing-ui): open settings when enterprise sub folks press usage indicator (#1194) 2025-08-29 16:11:32 -07:00
Siddharth Ganesan
7fcbafab97 Use direct fetch (#1193) 2025-08-29 16:10:36 -07:00
Siddharth Ganesan
056dc2879c Fix/wand (#1191)
* Switch to node

* Refactor
2025-08-29 15:50:26 -07:00
Siddharth Ganesan
1aec32b7e2 Switch to node (#1190) 2025-08-29 15:18:07 -07:00
Vikhyath Mondreti
316c9704af Merge pull request #1189 from simstudioai/staging
fix(deps): revert dependencies to before pg block was added
2025-08-29 14:28:31 -07:00
Vikhyath Mondreti
4e3a3bd1b1 run bun install 2025-08-29 14:23:31 -07:00
Vikhyath Mondreti
36773e8cdb Revert "feat(integrations): added parallel AI, mySQL, and postgres block/tools (#1126)"
This reverts commit 766279bb8b.
2025-08-29 14:14:45 -07:00
Vikhyath Mondreti
7ac89e35a1 revert(dep-changes): revert drizzle-orm version and change CI yaml script 2025-08-29 13:51:36 -07:00
Vikhyath Mondreti
faa094195a change bun install to be based on frozen-lockfile flag
2025-08-29 13:42:20 -07:00
Vikhyath Mondreti
69319d21cd revert drizzle-orm version 2025-08-29 13:36:57 -07:00
Vikhyath Mondreti
8362fd7a83 remove bun lock 2025-08-29 13:34:46 -07:00
Vikhyath Mondreti
39ad793a9a revert package.json 2025-08-29 13:34:19 -07:00
Waleed
921c755711 v0.3.40: drizzle fixes, custom postgres port support 2025-08-29 10:24:40 -07:00
Waleed
41ec75fcad fix(pg): fix POSTGRES_PORT envvar to map external port to 5432 internally (#1187) 2025-08-29 10:11:37 -07:00
Waleed
f2502f5e48 fix(database): revert changes related to db URL (#1185)
* fix(database): revert changes related to db URL

* cleanup
2025-08-29 09:33:40 -07:00
Vikhyath Mondreti
f3c4f7e20a fix 2025-08-29 00:35:15 -07:00
Vikhyath Mondreti
f578f43c9a graceful exit for drizzle migration 2025-08-29 00:25:47 -07:00
Vikhyath Mondreti
5c73038023 fix(db): attempt parsing cert and db url separately (#1183) 2025-08-29 00:17:05 -07:00
Waleed
92132024ca fix(db): accept self-signed certs (#1181) 2025-08-28 23:19:43 -07:00
Waleed
ed11456de3 fix(db): accept self-signed certs (#1181) 2025-08-28 23:19:02 -07:00
Waleed
8739a3d378 fix(ssl): add envvar for optional ssl cert (#1179) 2025-08-28 23:11:21 -07:00
Waleed
ca015deea9 fix(ssl): add envvar for optional ssl cert (#1179) 2025-08-28 23:00:43 -07:00
Waleed
fd6d927228 v0.3.40: copilot improvements, knowledgebase improvements, security improvements, billing fixes 2025-08-28 22:00:58 -07:00
Adam Gough
6ac59a3264 Revert "fix(cursor-and-input): fixes cursor and input canvas error (#1168)" (#1178)
This reverts commit aa84c75360.
2025-08-28 21:06:30 -07:00
Adam Gough
aa84c75360 fix(cursor-and-input): fixes cursor and input canvas error (#1168)
* fixed long input

* lint

* fix gray canvas

* fixed auto-pan

* remove duplicate useEffect

* fix auto-pan for wide mode

* removed any

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-28 20:17:10 -07:00
Vikhyath Mondreti
ebb8cf8bf9 fix(slack): set depends on for slack channel subblock (#1177)
* fix(slack): set depends on for slack channel

* use foreign credential check

* fix

* fix clearing of block
2025-08-28 20:11:30 -07:00
Siddharth Ganesan
cadfcdbfbd Fix (#1176) 2025-08-28 19:21:29 -07:00
Vikhyath Mondreti
7d62c200fa feat(openrouter): add open router to model block (#1172)
* feat(openrouter): add open router to model block

* improvement(openrouter): streaming fix, temperature fix

* pr comments

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-08-28 18:47:36 -07:00
Siddharth Ganesan
df646256b3 Revert "feat(debug): create debugger (#1174)" (#1175)
This reverts commit 7c73f5ffe0.
2025-08-28 18:46:40 -07:00
Siddharth Ganesan
7c73f5ffe0 feat(debug): create debugger (#1174)
* Updates

* Updates

* Updates

* Checkpoint

* Checkpoint

* Checkpoint

* Var improvements

* Fixes

* Execution status

* UI improvements

* Ui updates

* Fix

* Fix scoping

* Fix workflow vars

* Fix env vars

* Remove number styling

* Variable highlighting

* Updates

* Update

* Fix resume

* Stuff

* Breakpoint ui

* Ui

* Ui updates

* Loops and parallels

* Hide env vars

* Checkpoint

* Stuff

* Panel toggle

* Lint
2025-08-28 18:19:20 -07:00
Waleed
bb5f40a027 feat(pg): added ability to customize postgres port when running containerized app (#1173) 2025-08-28 17:16:24 -07:00
Waleed
5ae5429296 chore(deps): upgrade trigger.dev in gh action (#1171) 2025-08-28 17:08:59 -07:00
Waleed
fcf128f6db improvement(knowledge): remove innerJoin and add id identifiers to results, updated docs (#1170)
* improvement(knowledge): remove innerJoin and add id identifiers to results, updated docs

* cleanup

* add documentName to upload chunk op as well
2025-08-28 17:04:31 -07:00
Vikhyath Mondreti
56543dafb4 fix(billing): usage tracking cleanup, shared pool of limits for team/enterprise (#1131)
* fix(billing): team usage tracking cleanup, shared pool of limits for team

* address greptile comments

* fix lint

* remove usage of deprecated cols"

* update periodStart and periodEnd correctly

* fix lint

* fix type issue

* fix(billing): cleaned up billing, still more work to do on UI and population of data and consolidation

* fix upgrade

* cleanup

* progress

* works

* Remove 78th migration to prepare for merge with staging

* fix migration conflict

* remove useless test file

* fix

* Fix undefined seat pricing display and handle cancelled subscription seat updates

* cleanup code

* cleanup to use helpers for pulling pricing limits

* cleanup more things

* cleanup

* restore environment ts file

* remove unused files

* fix(team-management): fix team management UI, consolidate components

* use session data instead of subscription data in settings navigation

* remove unused code

* fix UI for enterprise plans

* added enterprise plan support

* progress

* billing state machine

* split overage and base into separate invoices

* fix badge logic

---------

Co-authored-by: waleedlatif1 <walif6@gmail.com>
2025-08-28 17:00:48 -07:00
Emir Karabeg
7cc4574913 improvement(knowledge): search returns document name (#1167) 2025-08-28 16:07:22 -07:00
Waleed
3f900947ce improvement(kb): use trigger.dev for kb tasks (#1166) 2025-08-28 12:14:31 -07:00
Waleed
bda8ee772a fix(security): strengthen email invite validation logic, fix invite page UI (#1162)
* fix(security): strengthen email invite validation logic, fix invite page UI

* ui
2025-08-28 00:03:03 -07:00
Siddharth Ganesan
104d34cc9e fix(copilot): context filtering (#1160)
* Add filter

* Scope kb and chats

* Lint

* Remove comments

* Lint
2025-08-27 22:57:28 -07:00
Siddharth Ganesan
06e9a6b302 feat(copilot): context (#1157)
* Copilot updates

* Set/get vars

* Credentials opener v1

* Progress

* Checkpoint?

* Context v1

* Workflow references

* Add knowledge base context

* Blocks

* Templates

* Much better pills

* workflow updates

* Major ui

* Workflow box colors

* Much improved ui

* Improvements

* Much better

* Add @ icon

* Welcome page

* Update tool names

* Matches

* Update ordering

* Good sort

* Good @ handling

* Update placeholder

* Updates

* Lint

* Almost there

* Wrapped up?

* Lint

* Build error fix

* Build fix?

* Lint

* Fix load vars
2025-08-27 21:07:51 -07:00
Waleed
fed4e507cc fix(signup): refetch session data on signup (#1155) 2025-08-27 20:01:04 -07:00
Waleed
389456e0f3 fix(envvars): fix split for pasting envvars with query params (#1156) 2025-08-27 19:55:54 -07:00
Vikhyath Mondreti
c720f23d9b fix(sockets): useCollabWorkflow cleanup, variables store logic simplification (#1154)
* fix(sockets): useCollabWorkflow cleanup, variables store logic simplification

* remove unnecessary check
2025-08-27 17:11:39 -07:00
Vikhyath Mondreti
89f7d2b943 improvement(sockets): cleanup debounce logic + add flush mechanism to… (#1152)
* improvement(sockets): cleanup debounce logic + add flush mechanism to not lose ops

* fix optimistic update overwritten race condition

* fix

* fix forever stuck in processing
2025-08-27 11:35:20 -07:00
Emir Karabeg
923c05239c fix(auto-layout): revert (#1148) 2025-08-26 23:24:09 -07:00
Waleed
3424a338b7 fix(security): fixed SSRF vulnerability (#1149) 2025-08-26 23:11:08 -07:00
Waleed
51b1e97fa2 fix(kb-uploads): created knowledge, chunks, tags services and use redis for queueing docs in kb (#1143)
* improvement(kb): created knowledge, chunks, tags services and use redis for queueing docs in kb

* moved directories around

* cleanup

* bulk create document records after upload is completed

* fix(copilot): send api key to sim agent (#1142)

* Fix api key auth

* Lint

* ack PR comments

* added sort by functionality for headers in kb table

* updated

* test fallback from redis, fix styling

* cleanup copilot, fixed tooltips

* feat: local auto layout (#1144)

* feat: added llms.txt and robots.txt (#1145)

* fix(condition-block): edges not following blocks, duplicate issues (#1146)

* fix(condition-block): edges not following blocks, duplicate issues

* add subblock update to setActiveWorkflow

* Update apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/components/condition-input.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* fix dependency array

* fix(copilot-cleanup): support azure blob upload in copilot, remove dead code & consolidate other copilot files (#1147)

* cleanup

* support azure blob image upload

* imports cleanup

* PR comments

* ack PR comments

* fix key validation

* improvement(forwarding+excel): added forwarding and improve excel read (#1136)

* added forwarding for outlook

* lint

* improved excel sheet read

* addressed greptile

* fixed bodytext getting truncated

* fixed any type

* added html func

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>

* revert agent const

* update docs

---------

Co-authored-by: Siddharth Ganesan <33737564+Sg312@users.noreply.github.com>
Co-authored-by: Emir Karabeg <78010029+emir-karabeg@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyathvikku@gmail.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Vikhyath Mondreti <vikhyath@simstudio.ai>
Co-authored-by: Adam Gough <77861281+aadamgough@users.noreply.github.com>
Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-26 22:55:18 -07:00
Adam Gough
ab74b13802 improvement(forwarding+excel): added forwarding and improve excel read (#1136)
* added forwarding for outlook

* lint

* improved excel sheet read

* addressed greptile

* fixed bodytext getting truncated

* fixed any type

* added html func

---------

Co-authored-by: Adam Gough <adamgough@Mac.attlocal.net>
2025-08-26 21:18:09 -07:00
Vikhyath Mondreti
861ab1446a Merge branch 'staging' of github.com:simstudioai/sim into staging 2025-08-26 20:09:13 -07:00
Vikhyath Mondreti
e6f519a5a6 fix dependency array 2025-08-26 20:08:37 -07:00
Waleed
8226e7b40a fix(copilot-cleanup): support azure blob upload in copilot, remove dead code & consolidate other copilot files (#1147)
* cleanup

* support azure blob image upload

* imports cleanup

* PR comments

* ack PR comments

* fix key validation
2025-08-26 20:06:43 -07:00
Vikhyath Mondreti
b177b291cf fix(condition-block): edges not following blocks, duplicate issues (#1146)
* fix(condition-block): edges not following blocks, duplicate issues

* add subblock update to setActiveWorkflow

* Update apps/sim/app/workspace/[workspaceId]/w/[workflowId]/components/workflow-block/components/sub-block/components/condition-input.tsx

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-26 19:51:55 -07:00
Emir Karabeg
9c3b43325b feat: added llms.txt and robots.txt (#1145) 2025-08-26 19:04:27 -07:00
Emir Karabeg
973a5c6497 feat: local auto layout (#1144) 2025-08-26 19:03:09 -07:00
Siddharth Ganesan
78437c688e fix(copilot): send api key to sim agent (#1142)
* Fix api key auth

* Lint
2025-08-26 16:01:42 -07:00
Vikhyath Mondreti
3b74250335 fix(subblock-race-condition): check loading state correctly (#1141)
* fix(subblock-race-condition): check loading state correctly"
;

* clean up

* remove useless comments

* fix date fallback
2025-08-26 12:14:58 -07:00
Waleed
c68800c772 feat(login): add terms and privacy to signup and login pages (#1139) 2025-08-26 11:19:17 -07:00
Siddharth Ganesan
5403665fa9 Docs update (#1140) 2025-08-26 11:16:07 -07:00
Siddharth Ganesan
3d3443f68e fix(copilot): enterprise api keys (#1138)
* Copilot enterprise

* Fix validation and enterprise azure keys

* Lint

* update tests

* Update

* Lint

* Remove hardcoded ishosted

* Lint

* Updates

* Add tests
2025-08-26 10:55:08 -07:00
Emir Karabeg
e5c0b14367 improvement(help-modal): ui/ux (#1135) 2025-08-25 19:36:38 -07:00
Siddharth Ganesan
a495516901 feat(copilot): enable azure openai and move key validation (#1134)
* Copilot enterprise

* Fix validation and enterprise azure keys

* Lint

* update tests

* Update

* Lint

* Remove hardcoded ishosted

* Lint
2025-08-25 18:03:08 -07:00
Waleed
1f9b4a8ef0 fix(wand): remove unstable__noStore, add additional logs for wand generation (#1133)
* feat(wand): added additional logs for wand generation

* remove unstable__noStore
2025-08-25 16:20:41 -07:00
Waleed
3372829c30 fix(wand): remove edge runtime for wand (#1132) 2025-08-25 14:21:27 -07:00
Waleed
45372aece5 fix(files): fix vulnerabilities in file uploads/deletes (#1130)
* fix(vulnerability): fix arbitrary file deletion vuln

* fix(uploads): fix vuln during upload

* cleanup
2025-08-25 11:26:42 -07:00
481 changed files with 70180 additions and 15392 deletions

View File

@@ -77,7 +77,7 @@ services:
       - POSTGRES_PASSWORD=postgres
       - POSTGRES_DB=simstudio
     ports:
-      - "5432:5432"
+      - "${POSTGRES_PORT:-5432}:5432"
     healthcheck:
       test: ["CMD-SHELL", "pg_isready -U postgres"]
       interval: 5s

View File

@@ -2,8 +2,7 @@ name: Build and Publish Docker Image
 on:
   push:
-    branches: [main]
-    tags: ['v*']
+    branches: [main, staging]
 jobs:
   build-and-push:
@@ -56,7 +55,7 @@ jobs:
         uses: docker/setup-buildx-action@v3
       - name: Log in to the Container registry
-        if: github.event_name != 'pull_request'
+        if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
         uses: docker/login-action@v3
         with:
           registry: ghcr.io
@@ -70,10 +69,7 @@ jobs:
           images: ${{ matrix.image }}
           tags: |
             type=raw,value=latest-${{ matrix.arch }},enable=${{ github.ref == 'refs/heads/main' }}
-            type=ref,event=pr,suffix=-${{ matrix.arch }}
-            type=semver,pattern={{version}},suffix=-${{ matrix.arch }}
-            type=semver,pattern={{major}}.{{minor}},suffix=-${{ matrix.arch }}
-            type=semver,pattern={{major}}.{{minor}}.{{patch}},suffix=-${{ matrix.arch }}
+            type=raw,value=staging-${{ github.sha }}-${{ matrix.arch }},enable=${{ github.ref == 'refs/heads/staging' }}
             type=sha,format=long,suffix=-${{ matrix.arch }}
       - name: Build and push Docker image
@@ -82,7 +78,7 @@ jobs:
           context: .
           file: ${{ matrix.dockerfile }}
           platforms: ${{ matrix.platform }}
-          push: ${{ github.event_name != 'pull_request' }}
+          push: ${{ github.event_name != 'pull_request' && github.ref == 'refs/heads/main' }}
           tags: ${{ steps.meta.outputs.tags }}
           labels: ${{ steps.meta.outputs.labels }}
           cache-from: type=gha,scope=build-v3
@@ -93,7 +89,7 @@ jobs:
   create-manifests:
     runs-on: ubuntu-latest
     needs: build-and-push
-    if: github.event_name != 'pull_request'
+    if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
     strategy:
       matrix:
         include:
@@ -119,10 +115,6 @@ jobs:
           images: ${{ matrix.image }}
           tags: |
             type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
-            type=ref,event=pr
-            type=semver,pattern={{version}}
-            type=semver,pattern={{major}}.{{minor}}
-            type=semver,pattern={{major}}.{{minor}}.{{patch}}
             type=sha,format=long
       - name: Create and push manifest

View File

@@ -26,7 +26,7 @@ jobs:
           node-version: latest
       - name: Install dependencies
-        run: bun install
+        run: bun install --frozen-lockfile
       - name: Run tests with coverage
         env:

View File

@@ -35,10 +35,10 @@ jobs:
       - name: Deploy to Staging
         if: github.ref == 'refs/heads/staging'
         working-directory: ./apps/sim
-        run: npx --yes trigger.dev@4.0.0 deploy -e staging
+        run: npx --yes trigger.dev@4.0.1 deploy -e staging
       - name: Deploy to Production
         if: github.ref == 'refs/heads/main'
         working-directory: ./apps/sim
-        run: npx --yes trigger.dev@4.0.0 deploy
+        run: npx --yes trigger.dev@4.0.1 deploy

View File

@@ -160,7 +160,6 @@ Copilot is a Sim-managed service. To use Copilot on a self-hosted instance:
 - Go to https://sim.ai → Settings → Copilot and generate a Copilot API key
 - Set `COPILOT_API_KEY` in your self-hosted environment to that value
-- Host Sim on a publicly available DNS and set NEXT_PUBLIC_APP_URL and BETTER_AUTH_URL to that value ([ngrok](https://ngrok.com/))
 ## Tech Stack

View File

@@ -7,8 +7,6 @@ import { Callout } from 'fumadocs-ui/components/callout'
import { Card, Cards } from 'fumadocs-ui/components/card'
import { MessageCircle, Package, Zap, Infinity as InfinityIcon, Brain, BrainCircuit } from 'lucide-react'
## What is Copilot
Copilot is your in-editor assistant that helps you build, understand, and improve workflows. It can:
- **Explain**: Answer questions about Sim and your current workflow
@@ -18,35 +16,34 @@ Copilot is your in-editor assistant that helps you build, understand, and improv
<Callout type="info">
Copilot is a Sim-managed service. For self-hosted deployments, generate a Copilot API key in the hosted app (sim.ai → Settings → Copilot)
1. Go to [sim.ai](https://sim.ai) → Settings → Copilot and generate a Copilot API key
2. Set `COPILOT_API_KEY` in your self-hosted environment to that value
3. Host Sim on a publicly available DNS and set `NEXT_PUBLIC_APP_URL` and `BETTER_AUTH_URL` to that value (e.g., using ngrok)
2. Set `COPILOT_API_KEY` in your self-hosted environment to that value
</Callout>
## Modes
<Cards>
<Card title="Ask">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<MessageCircle className="h-4 w-4 text-muted-foreground" />
Ask
</span>
<div>
<p className="m-0 text-sm">
Q&A mode for explanations, guidance, and suggestions without making changes to your workflow.
</p>
</div>
}
>
<div className="m-0 text-sm">
Q&A mode for explanations, guidance, and suggestions without making changes to your workflow.
</div>
</Card>
<Card title="Agent">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Package className="h-4 w-4 text-muted-foreground" />
Agent
</span>
<div>
<p className="m-0 text-sm">
Build-and-edit mode. Copilot proposes specific edits (add blocks, wire variables, tweak settings) and applies them when you approve.
</p>
</div>
}
>
<div className="m-0 text-sm">
Build-and-edit mode. Copilot proposes specific edits (add blocks, wire variables, tweak settings) and applies them when you approve.
</div>
</Card>
</Cards>
@@ -54,44 +51,71 @@ Copilot is your in-editor assistant that helps you build, understand, and improv
## Depth Levels
<Cards>
<Card title="Fast">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Zap className="h-4 w-4 text-muted-foreground" />
Fast
</span>
<div>
<p className="m-0 text-sm">Quickest and cheapest. Best for small edits, simple workflows, and minor tweaks.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Quickest and cheapest. Best for small edits, simple workflows, and minor tweaks.</div>
</Card>
<Card title="Auto">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<InfinityIcon className="h-4 w-4 text-muted-foreground" />
Auto
</span>
<div>
<p className="m-0 text-sm">Balanced speed and reasoning. Recommended default for most tasks.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Balanced speed and reasoning. Recommended default for most tasks.</div>
</Card>
<Card title="Pro">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<Brain className="h-4 w-4 text-muted-foreground" />
Advanced
</span>
<div>
<p className="m-0 text-sm">More reasoning for larger workflows and complex edits while staying performant.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">More reasoning for larger workflows and complex edits while staying performant.</div>
</Card>
<Card title="Max">
<div className="flex items-start gap-3">
<span className="mt-0.5 inline-flex h-8 w-8 items-center justify-center rounded-md border border-border/50 bg-muted/60">
<Card
title={
<span className="inline-flex items-center gap-2">
<BrainCircuit className="h-4 w-4 text-muted-foreground" />
Behemoth
</span>
<div>
<p className="m-0 text-sm">Maximum reasoning for deep planning, debugging, and complex architectural changes.</p>
</div>
</div>
}
>
<div className="m-0 text-sm">Maximum reasoning for deep planning, debugging, and complex architectural changes.</div>
</Card>
</Cards>
</Cards>
## Billing and Cost Calculation
### How Costs Are Calculated
Copilot usage is billed per token from the underlying LLM:
- **Input tokens**: billed at the provider's base rate (**at-cost**)
- **Output tokens**: billed at **1.5×** the provider's base output rate
```javascript
copilotCost = (inputTokens × inputPrice + outputTokens × (outputPrice × 1.5)) / 1,000,000
```
| Component | Rate Applied |
|----------|----------------------|
| Input | inputPrice |
| Output | outputPrice × 1.5 |
<Callout type="warning">
Pricing shown reflects rates as of September 4, 2025. Check provider documentation for current pricing.
</Callout>
<Callout type="info">
Model prices are per million tokens. The calculation divides by 1,000,000 to get the actual cost. See <a href="/execution/advanced#cost-calculation">Logging and Cost Calculation</a> for background and examples.
</Callout>

View File

@@ -33,6 +33,7 @@
   "microsoft_planner",
   "microsoft_teams",
   "mistral_parse",
+  "mongodb",
   "mysql",
   "notion",
   "onedrive",

View File

@@ -109,7 +109,7 @@ Read data from a Microsoft Excel spreadsheet
 | Parameter | Type | Required | Description |
 | --------- | ---- | -------- | ----------- |
 | `spreadsheetId` | string | Yes | The ID of the spreadsheet to read from |
-| `range` | string | No | The range of cells to read from |
+| `range` | string | No | The range of cells to read from. Accepts "SheetName!A1:B2" for explicit ranges or just "SheetName" to read the used range of that sheet. If omitted, reads the used range of the first sheet. |
 #### Output

View File

@@ -0,0 +1,264 @@
---
title: MongoDB
description: Connect to MongoDB database
---
import { BlockInfoCard } from "@/components/ui/block-info-card"
<BlockInfoCard
type="mongodb"
color="#E0E0E0"
icon={true}
iconSvg={`<svg className="block-icon" xmlns='http://www.w3.org/2000/svg' viewBox='0 0 128 128'>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='currentColor'
d='M88.038 42.812c1.605 4.643 2.761 9.383 3.141 14.296.472 6.095.256 12.147-1.029 18.142-.035.165-.109.32-.164.48-.403.001-.814-.049-1.208.012-3.329.523-6.655 1.065-9.981 1.604-3.438.557-6.881 1.092-10.313 1.687-1.216.21-2.721-.041-3.212 1.641-.014.046-.154.054-.235.08l.166-10.051-.169-24.252 1.602-.275c2.62-.429 5.24-.864 7.862-1.281 3.129-.497 6.261-.98 9.392-1.465 1.381-.215 2.764-.412 4.148-.618z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#45A538'
d='M61.729 110.054c-1.69-1.453-3.439-2.842-5.059-4.37-8.717-8.222-15.093-17.899-18.233-29.566-.865-3.211-1.442-6.474-1.627-9.792-.13-2.322-.318-4.665-.154-6.975.437-6.144 1.325-12.229 3.127-18.147l.099-.138c.175.233.427.439.516.702 1.759 5.18 3.505 10.364 5.242 15.551 5.458 16.3 10.909 32.604 16.376 48.9.107.318.384.579.583.866l-.87 2.969z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#46A037'
d='M88.038 42.812c-1.384.206-2.768.403-4.149.616-3.131.485-6.263.968-9.392 1.465-2.622.417-5.242.852-7.862 1.281l-1.602.275-.012-1.045c-.053-.859-.144-1.717-.154-2.576-.069-5.478-.112-10.956-.18-16.434-.042-3.429-.105-6.857-.175-10.285-.043-2.13-.089-4.261-.185-6.388-.052-1.143-.236-2.28-.311-3.423-.042-.657.016-1.319.029-1.979.817 1.583 1.616 3.178 2.456 4.749 1.327 2.484 3.441 4.314 5.344 6.311 7.523 7.892 12.864 17.068 16.193 27.433z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#409433'
d='M65.036 80.753c.081-.026.222-.034.235-.08.491-1.682 1.996-1.431 3.212-1.641 3.432-.594 6.875-1.13 10.313-1.687 3.326-.539 6.652-1.081 9.981-1.604.394-.062.805-.011 1.208-.012-.622 2.22-1.112 4.488-1.901 6.647-.896 2.449-1.98 4.839-3.131 7.182a49.142 49.142 0 01-6.353 9.763c-1.919 2.308-4.058 4.441-6.202 6.548-1.185 1.165-2.582 2.114-3.882 3.161l-.337-.23-1.214-1.038-1.256-2.753a41.402 41.402 0 01-1.394-9.838l.023-.561.171-2.426c.057-.828.133-1.655.168-2.485.129-2.982.241-5.964.359-8.946z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#4FAA41'
d='M65.036 80.753c-.118 2.982-.23 5.964-.357 8.947-.035.83-.111 1.657-.168 2.485l-.765.289c-1.699-5.002-3.399-9.951-5.062-14.913-2.75-8.209-5.467-16.431-8.213-24.642a4498.887 4498.887 0 00-6.7-19.867c-.105-.31-.407-.552-.617-.826l4.896-9.002c.168.292.39.565.496.879a6167.476 6167.476 0 016.768 20.118c2.916 8.73 5.814 17.467 8.728 26.198.116.349.308.671.491 1.062l.67-.78-.167 10.052z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#4AA73C'
d='M43.155 32.227c.21.274.511.516.617.826a4498.887 4498.887 0 016.7 19.867c2.746 8.211 5.463 16.433 8.213 24.642 1.662 4.961 3.362 9.911 5.062 14.913l.765-.289-.171 2.426-.155.559c-.266 2.656-.49 5.318-.814 7.968-.163 1.328-.509 2.632-.772 3.947-.198-.287-.476-.548-.583-.866-5.467-16.297-10.918-32.6-16.376-48.9a3888.972 3888.972 0 00-5.242-15.551c-.089-.263-.34-.469-.516-.702l3.272-8.84z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#57AE47'
d='M65.202 70.702l-.67.78c-.183-.391-.375-.714-.491-1.062-2.913-8.731-5.812-17.468-8.728-26.198a6167.476 6167.476 0 00-6.768-20.118c-.105-.314-.327-.588-.496-.879l6.055-7.965c.191.255.463.482.562.769 1.681 4.921 3.347 9.848 5.003 14.778 1.547 4.604 3.071 9.215 4.636 13.813.105.308.47.526.714.786l.012 1.045c.058 8.082.115 16.167.171 24.251z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#60B24F'
d='M65.021 45.404c-.244-.26-.609-.478-.714-.786-1.565-4.598-3.089-9.209-4.636-13.813-1.656-4.93-3.322-9.856-5.003-14.778-.099-.287-.371-.514-.562-.769 1.969-1.928 3.877-3.925 5.925-5.764 1.821-1.634 3.285-3.386 3.352-5.968.003-.107.059-.214.145-.514l.519 1.306c-.013.661-.072 1.322-.029 1.979.075 1.143.259 2.28.311 3.423.096 2.127.142 4.258.185 6.388.069 3.428.132 6.856.175 10.285.067 5.478.111 10.956.18 16.434.008.861.098 1.718.152 2.577z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#A9AA88'
d='M62.598 107.085c.263-1.315.609-2.62.772-3.947.325-2.649.548-5.312.814-7.968l.066-.01.066.011a41.402 41.402 0 001.394 9.838c-.176.232-.425.439-.518.701-.727 2.05-1.412 4.116-2.143 6.166-.1.28-.378.498-.574.744l-.747-2.566.87-2.969z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#B6B598'
d='M62.476 112.621c.196-.246.475-.464.574-.744.731-2.05 1.417-4.115 2.143-6.166.093-.262.341-.469.518-.701l1.255 2.754c-.248.352-.59.669-.728 1.061l-2.404 7.059c-.099.283-.437.483-.663.722l-.695-3.985z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#C2C1A7'
d='M63.171 116.605c.227-.238.564-.439.663-.722l2.404-7.059c.137-.391.48-.709.728-1.061l1.215 1.037c-.587.58-.913 1.25-.717 2.097l-.369 1.208c-.168.207-.411.387-.494.624-.839 2.403-1.64 4.819-2.485 7.222-.107.305-.404.544-.614.812-.109-1.387-.22-2.771-.331-4.158z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#CECDB7'
d='M63.503 120.763c.209-.269.506-.508.614-.812.845-2.402 1.646-4.818 2.485-7.222.083-.236.325-.417.494-.624l-.509 5.545c-.136.157-.333.294-.398.477-.575 1.614-1.117 3.24-1.694 4.854-.119.333-.347.627-.525.938-.158-.207-.441-.407-.454-.623-.051-.841-.016-1.688-.013-2.533z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#DBDAC7'
d='M63.969 123.919c.178-.312.406-.606.525-.938.578-1.613 1.119-3.239 1.694-4.854.065-.183.263-.319.398-.477l.012 3.64-1.218 3.124-1.411-.495z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#EBE9DC'
d='M65.38 124.415l1.218-3.124.251 3.696-1.469-.572z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#CECDB7'
d='M67.464 110.898c-.196-.847.129-1.518.717-2.097l.337.23-1.054 1.867z'
/>
<path
fillRule='evenodd'
clipRule='evenodd'
fill='#4FAA41'
d='M64.316 95.172l-.066-.011-.066.01.155-.559-.023.56z'
/>
</svg>`}
/>
## Usage Instructions
Connect to any MongoDB database to execute queries, manage data, and perform database operations. Supports find, insert, update, delete, and aggregation operations with secure connection handling.
## Tools
### `mongodb_query`
Execute find operation on MongoDB collection
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MongoDB server hostname or IP address |
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | No | MongoDB username |
| `password` | string | No | MongoDB password |
| `authSource` | string | No | Authentication database |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `collection` | string | Yes | Collection name to query |
| `query` | string | No | MongoDB query filter as JSON string |
| `limit` | number | No | Maximum number of documents to return |
| `sort` | string | No | Sort criteria as JSON string |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `documents` | array | Array of documents returned from the query |
| `documentCount` | number | Number of documents returned |
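For illustration, a hypothetical `mongodb_query` input (placeholder host and credentials; note that `query` and `sort` are passed as JSON strings, not objects):

```typescript
// Placeholder values; query and sort are JSON strings per the table above
const queryInput = {
  host: 'db.example.com',
  port: 27017,
  database: 'appdb',
  username: 'reader',
  password: 'secret', // placeholder; use a secret store in practice
  authSource: 'admin',
  ssl: 'required',
  collection: 'orders',
  query: JSON.stringify({ status: 'shipped' }),
  sort: JSON.stringify({ createdAt: -1 }),
  limit: 50,
}
```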
### `mongodb_insert`
Insert documents into MongoDB collection
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MongoDB server hostname or IP address |
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | No | MongoDB username |
| `password` | string | No | MongoDB password |
| `authSource` | string | No | Authentication database |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `collection` | string | Yes | Collection name to insert into |
| `documents` | array | Yes | Array of documents to insert |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `documentCount` | number | Number of documents inserted |
| `insertedId` | string | ID of inserted document \(single insert\) |
| `insertedIds` | array | Array of inserted document IDs \(multiple insert\) |
### `mongodb_update`
Update documents in MongoDB collection
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MongoDB server hostname or IP address |
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | No | MongoDB username |
| `password` | string | No | MongoDB password |
| `authSource` | string | No | Authentication database |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `collection` | string | Yes | Collection name to update |
| `filter` | string | Yes | Filter criteria as JSON string |
| `update` | string | Yes | Update operations as JSON string |
| `upsert` | boolean | No | Create document if not found |
| `multi` | boolean | No | Update multiple documents |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `matchedCount` | number | Number of documents matched by filter |
| `modifiedCount` | number | Number of documents modified |
| `documentCount` | number | Total number of documents affected |
| `insertedId` | string | ID of inserted document \(if upsert\) |
### `mongodb_delete`
Delete documents from MongoDB collection
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MongoDB server hostname or IP address |
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | No | MongoDB username |
| `password` | string | No | MongoDB password |
| `authSource` | string | No | Authentication database |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `collection` | string | Yes | Collection name to delete from |
| `filter` | string | Yes | Filter criteria as JSON string |
| `multi` | boolean | No | Delete multiple documents |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `deletedCount` | number | Number of documents deleted |
| `documentCount` | number | Total number of documents affected |
### `mongodb_execute`
Execute MongoDB aggregation pipeline
#### Input
| Parameter | Type | Required | Description |
| --------- | ---- | -------- | ----------- |
| `host` | string | Yes | MongoDB server hostname or IP address |
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
| `database` | string | Yes | Database name to connect to |
| `username` | string | No | MongoDB username |
| `password` | string | No | MongoDB password |
| `authSource` | string | No | Authentication database |
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
| `collection` | string | Yes | Collection name to execute pipeline on |
| `pipeline` | string | Yes | Aggregation pipeline as JSON string |
#### Output
| Parameter | Type | Description |
| --------- | ---- | ----------- |
| `message` | string | Operation status message |
| `documents` | array | Array of documents returned from aggregation |
| `documentCount` | number | Number of documents returned |
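As with `mongodb_query`, the aggregation pipeline is serialized as a JSON string; a hypothetical example:

```typescript
// Hypothetical aggregation: count shipped orders per customer, busiest first
const executeInput = {
  collection: 'orders',
  pipeline: JSON.stringify([
    { $match: { status: 'shipped' } },
    { $group: { _id: '$customerId', orders: { $sum: 1 } } },
    { $sort: { orders: -1 } },
  ]),
  // ...connection fields (host, port, database, etc.) as in the tables above
}
```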
## Notes
- Category: `tools`
- Type: `mongodb`

View File

@@ -68,7 +68,7 @@ Upload a file to OneDrive
 | `fileName` | string | Yes | The name of the file to upload |
 | `content` | string | Yes | The content of the file to upload |
 | `folderSelector` | string | No | Select the folder to upload the file to |
-| `folderId` | string | No | The ID of the folder to upload the file to \(internal use\) |
+| `manualFolderId` | string | No | Manually entered folder ID \(advanced mode\) |
 #### Output
@@ -87,7 +87,7 @@ Create a new folder in OneDrive
 | --------- | ---- | -------- | ----------- |
 | `folderName` | string | Yes | Name of the folder to create |
 | `folderSelector` | string | No | Select the parent folder to create the folder in |
-| `folderId` | string | No | ID of the parent folder \(internal use\) |
+| `manualFolderId` | string | No | Manually entered parent folder ID \(advanced mode\) |
 #### Output
@@ -105,7 +105,7 @@ List files and folders in OneDrive
 | Parameter | Type | Required | Description |
 | --------- | ---- | -------- | ----------- |
 | `folderSelector` | string | No | Select the folder to list files from |
-| `folderId` | string | No | The ID of the folder to list files from \(internal use\) |
+| `manualFolderId` | string | No | The manually entered folder ID \(advanced mode\) |
 | `query` | string | No | A query to filter the files |
 | `pageSize` | number | No | The number of files to return |

View File

@@ -211,10 +211,27 @@ Read emails from Outlook
 | Parameter | Type | Description |
 | --------- | ---- | ----------- |
 | `success` | boolean | Email read operation success status |
-| `messageCount` | number | Number of emails retrieved |
-| `messages` | array | Array of email message objects |
+| `message` | string | Success or status message |
+| `results` | array | Array of email message objects |
+
+### `outlook_forward`
+
+Forward an existing Outlook message to specified recipients
+
+#### Input
+
+| Parameter | Type | Required | Description |
+| --------- | ---- | -------- | ----------- |
+| `messageId` | string | Yes | The ID of the message to forward |
+| `to` | string | Yes | Recipient email address\(es\), comma-separated |
+| `comment` | string | No | Optional comment to include with the forwarded message |
+
+#### Output
+
+| Parameter | Type | Description |
+| --------- | ---- | ----------- |
+| `message` | string | Success or error message |
+| `results` | object | Delivery result details |
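A hypothetical `outlook_forward` input, with comma-separated recipients as described above:

```typescript
// Placeholder message ID and recipients
const forwardInput = {
  messageId: 'AAMkAGI2NGVhZTVl', // placeholder Outlook message ID
  to: 'alice@example.com, bob@example.com',
  comment: 'Forwarding for visibility.',
}
```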

View File

@@ -1,6 +1,9 @@
 # Database (Required)
 DATABASE_URL="postgresql://postgres:password@localhost:5432/postgres"
+
+# PostgreSQL Port (Optional) - defaults to 5432 if not specified
+# POSTGRES_PORT=5432
 # Authentication (Required)
 BETTER_AUTH_SECRET=your_secret_key # Use `openssl rand -hex 32` to generate, or visit https://www.better-auth.com/docs/installation
 BETTER_AUTH_URL=http://localhost:3000

View File

@@ -49,15 +49,12 @@ const PASSWORD_VALIDATIONS = {
},
}
// Validate callback URL to prevent open redirect vulnerabilities
const validateCallbackUrl = (url: string): boolean => {
try {
// If it's a relative URL, it's safe
if (url.startsWith('/')) {
return true
}
// If absolute URL, check if it belongs to the same origin
const currentOrigin = typeof window !== 'undefined' ? window.location.origin : ''
if (url.startsWith(currentOrigin)) {
return true
@@ -70,7 +67,6 @@ const validateCallbackUrl = (url: string): boolean => {
}
}
// Validate password and return array of error messages
const validatePassword = (passwordValue: string): string[] => {
const errors: string[] = []
@@ -308,6 +304,15 @@ export default function LoginPage({
return
}
const emailValidation = quickValidateEmail(forgotPasswordEmail.trim().toLowerCase())
if (!emailValidation.isValid) {
setResetStatus({
type: 'error',
message: 'Please enter a valid email address',
})
return
}
try {
setIsSubmittingReset(true)
setResetStatus({ type: null, message: '' })
@@ -325,7 +330,23 @@ export default function LoginPage({
if (!response.ok) {
const errorData = await response.json()
throw new Error(errorData.message || 'Failed to request password reset')
let errorMessage = errorData.message || 'Failed to request password reset'
if (
errorMessage.includes('Invalid body parameters') ||
errorMessage.includes('invalid email')
) {
errorMessage = 'Please enter a valid email address'
} else if (errorMessage.includes('Email is required')) {
errorMessage = 'Please enter your email address'
} else if (
errorMessage.includes('user not found') ||
errorMessage.includes('User not found')
) {
errorMessage = 'No account found with this email address'
}
throw new Error(errorMessage)
}
setResetStatus({
@@ -475,6 +496,23 @@ export default function LoginPage({
Sign up
</Link>
</div>
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
By signing in, you agree to our{' '}
<Link
href='/terms'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Terms of Service
</Link>{' '}
and{' '}
<Link
href='/privacy'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Privacy Policy
</Link>
</div>
</div>
<Dialog open={forgotPasswordOpen} onOpenChange={setForgotPasswordOpen}>
@@ -484,7 +522,8 @@ export default function LoginPage({
Reset Password
</DialogTitle>
<DialogDescription className='text-neutral-300 text-sm'>
Enter your email address and we'll send you a link to reset your password.
Enter your email address and we'll send you a link to reset your password if your
account exists.
</DialogDescription>
</DialogHeader>
<div className='space-y-4'>
@@ -499,16 +538,20 @@ export default function LoginPage({
placeholder='Enter your email'
required
type='email'
className='border-neutral-700/80 bg-neutral-900 text-white placeholder:text-white/60 focus:border-[var(--brand-primary-hover-hex)]/70 focus:ring-[var(--brand-primary-hover-hex)]/20'
className={cn(
'border-neutral-700/80 bg-neutral-900 text-white placeholder:text-white/60 focus:border-[var(--brand-primary-hover-hex)]/70 focus:ring-[var(--brand-primary-hover-hex)]/20',
resetStatus.type === 'error' && 'border-red-500 focus-visible:ring-red-500'
)}
/>
{resetStatus.type === 'error' && (
<div className='mt-1 space-y-1 text-red-400 text-xs'>
<p>{resetStatus.message}</p>
</div>
)}
</div>
{resetStatus.type && (
<div
className={`text-sm ${
resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'
}`}
>
{resetStatus.message}
{resetStatus.type === 'success' && (
<div className='mt-1 space-y-1 text-[#4CAF50] text-xs'>
<p>{resetStatus.message}</p>
</div>
)}
<Button

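For reference, a self-contained approximation of the `validateCallbackUrl` guard shown in the first hunk of this file (assuming the same relative-URL and same-origin rules; not the exact source):

```typescript
// Approximation: accept relative paths and same-origin absolute URLs,
// reject everything else to prevent open redirects.
const validateCallbackUrl = (url: string): boolean => {
  try {
    // Relative URLs stay on this site
    if (url.startsWith('/')) return true
    // Absolute URLs must match the current origin
    const currentOrigin = typeof window !== 'undefined' ? window.location.origin : ''
    return currentOrigin !== '' && url.startsWith(currentOrigin)
  } catch {
    return false
  }
}

// e.g. validateCallbackUrl('/workspace/123') === true
// e.g. validateCallbackUrl('https://evil.example') === false
```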
View File

@@ -5,7 +5,7 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { useRouter, useSearchParams } from 'next/navigation'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import SignupPage from '@/app/(auth)/signup/signup-form'
vi.mock('next/navigation', () => ({
@@ -22,6 +22,7 @@ vi.mock('@/lib/auth-client', () => ({
sendVerificationOtp: vi.fn(),
},
},
useSession: vi.fn(),
}))
vi.mock('@/app/(auth)/components/social-login-buttons', () => ({
@@ -43,6 +44,9 @@ describe('SignupPage', () => {
vi.clearAllMocks()
;(useRouter as any).mockReturnValue(mockRouter)
;(useSearchParams as any).mockReturnValue(mockSearchParams)
;(useSession as any).mockReturnValue({
refetch: vi.fn().mockResolvedValue({}),
})
mockSearchParams.get.mockReturnValue(null)
})
@@ -162,8 +166,9 @@ describe('SignupPage', () => {
})
})
it('should prevent submission with invalid name validation', async () => {
it('should automatically trim spaces from name input', async () => {
const mockSignUp = vi.mocked(client.signUp.email)
mockSignUp.mockResolvedValue({ data: null, error: null })
render(<SignupPage {...defaultProps} />)
@@ -172,22 +177,20 @@ describe('SignupPage', () => {
const passwordInput = screen.getByPlaceholderText(/enter your password/i)
const submitButton = screen.getByRole('button', { name: /create account/i })
// Use name with leading/trailing spaces which should fail validation
fireEvent.change(nameInput, { target: { value: ' John Doe ' } })
fireEvent.change(emailInput, { target: { value: 'user@company.com' } })
fireEvent.change(passwordInput, { target: { value: 'Password123!' } })
fireEvent.click(submitButton)
// Should not call signUp because validation failed
expect(mockSignUp).not.toHaveBeenCalled()
// Should show validation error
await waitFor(() => {
expect(
screen.getByText(
/Name cannot contain consecutive spaces|Name cannot start or end with spaces/
)
).toBeInTheDocument()
expect(mockSignUp).toHaveBeenCalledWith(
expect.objectContaining({
name: 'John Doe',
email: 'user@company.com',
password: 'Password123!',
}),
expect.any(Object)
)
})
})

View File

@@ -7,7 +7,7 @@ import { useRouter, useSearchParams } from 'next/navigation'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import { quickValidateEmail } from '@/lib/email/validation'
import { createLogger } from '@/lib/logs/console/logger'
import { cn } from '@/lib/utils'
@@ -49,10 +49,6 @@ const NAME_VALIDATIONS = {
regex: /^(?!.*\s\s).*$/,
message: 'Name cannot contain consecutive spaces.',
},
noLeadingTrailingSpaces: {
test: (value: string) => value === value.trim(),
message: 'Name cannot start or end with spaces.',
},
}
const validateEmailField = (emailValue: string): string[] => {
@@ -82,6 +78,7 @@ function SignupFormContent({
}) {
const router = useRouter()
const searchParams = useSearchParams()
const { refetch: refetchSession } = useSession()
const [isLoading, setIsLoading] = useState(false)
const [, setMounted] = useState(false)
const [showPassword, setShowPassword] = useState(false)
@@ -174,10 +171,6 @@ function SignupFormContent({
errors.push(NAME_VALIDATIONS.noConsecutiveSpaces.message)
}
if (!NAME_VALIDATIONS.noLeadingTrailingSpaces.test(nameValue)) {
errors.push(NAME_VALIDATIONS.noLeadingTrailingSpaces.message)
}
return errors
}
@@ -192,11 +185,10 @@ function SignupFormContent({
}
const handleNameChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const newName = e.target.value
setName(newName)
const rawValue = e.target.value
setName(rawValue)
// Silently validate but don't show errors until submit
const errors = validateName(newName)
const errors = validateName(rawValue)
setNameErrors(errors)
setShowNameValidationError(false)
}
@@ -223,23 +215,21 @@ function SignupFormContent({
const formData = new FormData(e.currentTarget)
const emailValue = formData.get('email') as string
const passwordValue = formData.get('password') as string
const name = formData.get('name') as string
const nameValue = formData.get('name') as string
// Validate name on submit
const nameValidationErrors = validateName(name)
const trimmedName = nameValue.trim()
const nameValidationErrors = validateName(trimmedName)
setNameErrors(nameValidationErrors)
setShowNameValidationError(nameValidationErrors.length > 0)
// Validate email on submit
const emailValidationErrors = validateEmailField(emailValue)
setEmailErrors(emailValidationErrors)
setShowEmailValidationError(emailValidationErrors.length > 0)
// Validate password on submit
const errors = validatePassword(passwordValue)
setPasswordErrors(errors)
// Only show validation errors if there are any
setShowValidationError(errors.length > 0)
try {
@@ -248,7 +238,6 @@ function SignupFormContent({
emailValidationErrors.length > 0 ||
errors.length > 0
) {
// Prioritize name errors first, then email errors, then password errors
if (nameValidationErrors.length > 0) {
setNameErrors([nameValidationErrors[0]])
setShowNameValidationError(true)
@@ -265,8 +254,6 @@ function SignupFormContent({
return
}
// Check if name will be truncated and warn user
const trimmedName = name.trim()
if (trimmedName.length > 100) {
setNameErrors(['Name will be truncated to 100 characters. Please shorten your name.'])
setShowNameValidationError(true)
@@ -330,6 +317,14 @@ function SignupFormContent({
return
}
// Refresh session to get the new user data immediately after signup
try {
await refetchSession()
logger.info('Session refreshed after successful signup')
} catch (sessionError) {
logger.error('Failed to refresh session after signup:', sessionError)
}
// For new signups, always require verification
if (typeof window !== 'undefined') {
sessionStorage.setItem('verificationEmail', emailValue)
@@ -507,6 +502,23 @@ function SignupFormContent({
Sign in
</Link>
</div>
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
By creating an account, you agree to our{' '}
<Link
href='/terms'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Terms of Service
</Link>{' '}
and{' '}
<Link
href='/privacy'
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
>
Privacy Policy
</Link>
</div>
</div>
</div>
)
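
Editor's note: the form now also refreshes the Better Auth session immediately after a successful signup so the next page sees the new user, and treats a refresh failure as non-fatal. A sketch of that pattern, with refetchSession standing in for useSession().refetch:

// Non-fatal post-signup session refresh; refetchSession stands in for
// useSession().refetch from the auth client.
async function refreshAfterSignup(
  refetchSession: () => Promise<unknown>,
  logger: { info: (msg: string) => void; error: (msg: string, err: unknown) => void }
): Promise<void> {
  try {
    await refetchSession()
    logger.info('Session refreshed after successful signup')
  } catch (err) {
    // Signup itself succeeded; a stale session is recoverable, so only log.
    logger.error('Failed to refresh session after signup:', err)
  }
}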

View File

@@ -2,7 +2,7 @@
import { useEffect, useState } from 'react'
import { useRouter, useSearchParams } from 'next/navigation'
import { client } from '@/lib/auth-client'
import { client, useSession } from '@/lib/auth-client'
import { env, isTruthy } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
@@ -34,6 +34,7 @@ export function useVerification({
}: UseVerificationParams): UseVerificationReturn {
const router = useRouter()
const searchParams = useSearchParams()
const { refetch: refetchSession } = useSession()
const [otp, setOtp] = useState('')
const [email, setEmail] = useState('')
const [isLoading, setIsLoading] = useState(false)
@@ -136,16 +137,15 @@ export function useVerification({
}
}
// Redirect to proper page after a short delay
setTimeout(() => {
if (isInviteFlow && redirectUrl) {
// For invitation flow, redirect to the invitation page
router.push(redirectUrl)
window.location.href = redirectUrl
} else {
// Default redirect to dashboard
router.push('/workspace')
window.location.href = '/workspace'
}
}, 2000)
}, 1000)
} else {
logger.info('Setting invalid OTP state - API error response')
const message = 'Invalid verification code. Please check and try again.'
@@ -215,25 +215,33 @@ export function useVerification({
setOtp(value)
}
// Auto-submit when OTP is complete
useEffect(() => {
if (otp.length === 6 && email && !isLoading && !isVerified) {
const timeoutId = setTimeout(() => {
verifyCode()
}, 300) // Small delay to ensure UI is ready
return () => clearTimeout(timeoutId)
}
}, [otp, email, isLoading, isVerified])
useEffect(() => {
if (typeof window !== 'undefined') {
if (!isProduction || !hasResendKey) {
const storedEmail = sessionStorage.getItem('verificationEmail')
logger.info('Auto-verifying user', { email: storedEmail })
}
const isDevOrDocker = !isProduction || isTruthy(env.DOCKER_BUILD)
// Auto-verify and redirect in development/docker environments
if (isDevOrDocker || !hasResendKey) {
setIsVerified(true)
// Clear verification requirement cookie (same as manual verification)
document.cookie =
'requiresEmailVerification=; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT'
const timeoutId = setTimeout(() => {
router.push('/workspace')
window.location.href = '/workspace'
}, 1000)
return () => clearTimeout(timeoutId)
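
Editor's note: both redirects in this hook switch from router.push to window.location.href. A hard navigation forces a full document load, presumably so the freshly set session cookie is guaranteed to be sent on the next request, where a client-side router.push would reuse current page state. Sketch of the equivalent helper the hook inlines:

// Hard redirect; unlike router.push, assigning window.location.href
// triggers a full page load.
function hardRedirect(url: string): void {
  if (typeof window !== 'undefined') {
    window.location.href = url
  }
}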

View File

@@ -0,0 +1,7 @@
import { toNextJsHandler } from 'better-auth/next-js'
import { auth } from '@/lib/auth'
export const dynamic = 'force-dynamic'
// Handle Stripe webhooks through better-auth
export const { GET, POST } = toNextJsHandler(auth.handler)

View File

@@ -0,0 +1,77 @@
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { subscription as subscriptionTable, user } from '@/db/schema'
const logger = createLogger('BillingPortal')
export async function POST(request: NextRequest) {
const session = await getSession()
try {
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const body = await request.json().catch(() => ({}))
const context: 'user' | 'organization' =
body?.context === 'organization' ? 'organization' : 'user'
const organizationId: string | undefined = body?.organizationId || undefined
const returnUrl: string =
body?.returnUrl || `${env.NEXT_PUBLIC_APP_URL}/workspace?billing=updated`
const stripe = requireStripeClient()
let stripeCustomerId: string | null = null
if (context === 'organization') {
if (!organizationId) {
return NextResponse.json({ error: 'organizationId is required' }, { status: 400 })
}
const rows = await db
.select({ customer: subscriptionTable.stripeCustomerId })
.from(subscriptionTable)
.where(
and(
eq(subscriptionTable.referenceId, organizationId),
eq(subscriptionTable.status, 'active')
)
)
.limit(1)
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
} else {
const rows = await db
.select({ customer: user.stripeCustomerId })
.from(user)
.where(eq(user.id, session.user.id))
.limit(1)
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
}
if (!stripeCustomerId) {
logger.error('Stripe customer not found for portal session', {
context,
organizationId,
userId: session.user.id,
})
return NextResponse.json({ error: 'Stripe customer not found' }, { status: 404 })
}
const portal = await stripe.billingPortal.sessions.create({
customer: stripeCustomerId,
return_url: returnUrl,
})
return NextResponse.json({ url: portal.url })
} catch (error) {
logger.error('Failed to create billing portal session', { error })
return NextResponse.json({ error: 'Failed to create billing portal session' }, { status: 500 })
}
}
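
Editor's note: a hypothetical client-side caller for this handler. The /api/billing/portal path is an assumption (the diff view hides file paths); the body shape follows what the handler reads:

// Illustrative caller; the route path is assumed.
async function openBillingPortal(organizationId?: string): Promise<void> {
  const res = await fetch('/api/billing/portal', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify(
      organizationId ? { context: 'organization', organizationId } : { context: 'user' }
    ),
  })
  if (!res.ok) throw new Error(`Portal session failed: ${res.status}`)
  const { url } = (await res.json()) as { url: string }
  window.location.href = url // Stripe-hosted billing portal
}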

View File

@@ -2,10 +2,10 @@ import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getSimplifiedBillingSummary } from '@/lib/billing/core/billing'
import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
import { getOrganizationBillingData } from '@/lib/billing/core/organization'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { member } from '@/db/schema'
import { member, userStats } from '@/db/schema'
const logger = createLogger('UnifiedBillingAPI')
@@ -45,6 +45,16 @@ export async function GET(request: NextRequest) {
if (context === 'user') {
// Get user billing (may include organization if they're part of one)
billingData = await getSimplifiedBillingSummary(session.user.id, contextId || undefined)
// Attach billingBlocked status for the current user
const stats = await db
.select({ blocked: userStats.billingBlocked })
.from(userStats)
.where(eq(userStats.userId, session.user.id))
.limit(1)
billingData = {
...billingData,
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
}
} else {
// Get user role in organization for permission checks first
const memberRecord = await db
@@ -78,8 +88,10 @@ export async function GET(request: NextRequest) {
subscriptionStatus: rawBillingData.subscriptionStatus,
totalSeats: rawBillingData.totalSeats,
usedSeats: rawBillingData.usedSeats,
seatsCount: rawBillingData.seatsCount,
totalCurrentUsage: rawBillingData.totalCurrentUsage,
totalUsageLimit: rawBillingData.totalUsageLimit,
minimumBillingAmount: rawBillingData.minimumBillingAmount,
averageUsagePerMember: rawBillingData.averageUsagePerMember,
billingPeriodStart: rawBillingData.billingPeriodStart?.toISOString() || null,
billingPeriodEnd: rawBillingData.billingPeriodEnd?.toISOString() || null,
@@ -92,11 +104,25 @@ export async function GET(request: NextRequest) {
const userRole = memberRecord[0].role
// Include the requesting user's blocked flag as well so UI can reflect it
const stats = await db
.select({ blocked: userStats.billingBlocked })
.from(userStats)
.where(eq(userStats.userId, session.user.id))
.limit(1)
// Merge blocked flag into data for convenience
billingData = {
...billingData,
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
}
return NextResponse.json({
success: true,
context,
data: billingData,
userRole,
billingBlocked: billingData.billingBlocked,
})
}
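
Editor's note: the organization branch now reports the requesting user's blocked flag both inside data and at the top level. An assumed response shape, inferred from the handler above (field names only; not a published type):

// Assumed shape; field names are read off the route above.
interface OrganizationBillingResponse {
  success: true
  context: 'organization'
  data: {
    billingBlocked: boolean
    seatsCount: number
    minimumBillingAmount: number
    [key: string]: unknown // remaining serialized summary fields
  }
  userRole: string
  billingBlocked: boolean
}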

View File

@@ -16,7 +16,8 @@ const UpdateCostSchema = z.object({
input: z.number().min(0, 'Input tokens must be a non-negative number'),
output: z.number().min(0, 'Output tokens must be a non-negative number'),
model: z.string().min(1, 'Model is required'),
multiplier: z.number().min(0),
inputMultiplier: z.number().min(0),
outputMultiplier: z.number().min(0),
})
/**
@@ -75,14 +76,15 @@ export async function POST(req: NextRequest) {
)
}
const { userId, input, output, model, multiplier } = validation.data
const { userId, input, output, model, inputMultiplier, outputMultiplier } = validation.data
logger.info(`[${requestId}] Processing cost update`, {
userId,
input,
output,
model,
multiplier,
inputMultiplier,
outputMultiplier,
})
const finalPromptTokens = input
@@ -95,7 +97,8 @@ export async function POST(req: NextRequest) {
finalPromptTokens,
finalCompletionTokens,
false,
multiplier
inputMultiplier,
outputMultiplier
)
logger.info(`[${requestId}] Cost calculation result`, {
@@ -104,7 +107,8 @@ export async function POST(req: NextRequest) {
promptTokens: finalPromptTokens,
completionTokens: finalCompletionTokens,
totalTokens: totalTokens,
multiplier,
inputMultiplier,
outputMultiplier,
costResult,
})
@@ -115,52 +119,34 @@ export async function POST(req: NextRequest) {
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
if (userStatsRecords.length === 0) {
// Create new user stats record (same logic as ExecutionLogger)
await db.insert(userStats).values({
id: crypto.randomUUID(),
userId: userId,
totalManualExecutions: 0,
totalApiCalls: 0,
totalWebhookTriggers: 0,
totalScheduledExecutions: 0,
totalChatExecutions: 0,
totalTokensUsed: totalTokens,
totalCost: costToStore.toString(),
currentPeriodCost: costToStore.toString(),
// Copilot usage tracking
totalCopilotCost: costToStore.toString(),
totalCopilotTokens: totalTokens,
totalCopilotCalls: 1,
lastActive: new Date(),
})
logger.info(`[${requestId}] Created new user stats record`, {
userId,
totalCost: costToStore,
totalTokens,
})
} else {
// Update existing user stats record (same logic as ExecutionLogger)
const updateFields = {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
// Copilot usage tracking increments
totalCopilotCost: sql`total_copilot_cost + ${costToStore}`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalApiCalls: sql`total_api_calls`,
lastActive: new Date(),
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.info(`[${requestId}] Updated user stats record`, {
userId,
addedCost: costToStore,
addedTokens: totalTokens,
})
logger.error(
`[${requestId}] User stats record not found - should be created during onboarding`,
{
userId,
}
)
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
}
// Update existing user stats record (same logic as ExecutionLogger)
const updateFields = {
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
totalCost: sql`total_cost + ${costToStore}`,
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
// Copilot usage tracking increments
totalCopilotCost: sql`total_copilot_cost + ${costToStore}`,
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
totalCopilotCalls: sql`total_copilot_calls + 1`,
totalApiCalls: sql`total_api_calls`,
lastActive: new Date(),
}
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
logger.info(`[${requestId}] Updated user stats record`, {
userId,
addedCost: costToStore,
addedTokens: totalTokens,
})
const duration = Date.now() - startTime
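
Editor's note: the single multiplier is split so input and output token costs can be scaled independently. The production calculateCost helper is not shown in this diff, but the arithmetic the split enables looks roughly like this (prices and multipliers are illustrative, not the app's rates):

// Hypothetical per-token pricing; the real calculateCost takes the
// (promptTokens, completionTokens, cached, inputMultiplier, outputMultiplier)
// arguments shown above, but its internals are not in this diff.
function estimateCost(
  promptTokens: number,
  completionTokens: number,
  inputPricePerMTok: number,
  outputPricePerMTok: number,
  inputMultiplier: number,
  outputMultiplier: number
): number {
  const inputCost = (promptTokens / 1_000_000) * inputPricePerMTok * inputMultiplier
  const outputCost = (completionTokens / 1_000_000) * outputPricePerMTok * outputMultiplier
  return inputCost + outputCost
}

// e.g. 10k prompt tokens at $3/MTok with a 1.5x input multiplier, plus
// 2k completion tokens at $15/MTok with a 2x output multiplier:
// 0.01 * 3 * 1.5 + 0.002 * 15 * 2 = 0.045 + 0.06 = $0.105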

View File

@@ -1,116 +0,0 @@
import { headers } from 'next/headers'
import { type NextRequest, NextResponse } from 'next/server'
import type Stripe from 'stripe'
import { requireStripeClient } from '@/lib/billing/stripe-client'
import { handleInvoiceWebhook } from '@/lib/billing/webhooks/stripe-invoice-webhooks'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('StripeInvoiceWebhook')
/**
* Stripe billing webhook endpoint for invoice-related events
* Endpoint: /api/billing/webhooks/stripe
* Handles: invoice.payment_succeeded, invoice.payment_failed, invoice.finalized
*/
export async function POST(request: NextRequest) {
try {
const body = await request.text()
const headersList = await headers()
const signature = headersList.get('stripe-signature')
if (!signature) {
logger.error('Missing Stripe signature header')
return NextResponse.json({ error: 'Missing Stripe signature' }, { status: 400 })
}
if (!env.STRIPE_BILLING_WEBHOOK_SECRET) {
logger.error('Missing Stripe webhook secret configuration')
return NextResponse.json({ error: 'Webhook secret not configured' }, { status: 500 })
}
// Check if Stripe client is available
let stripe
try {
stripe = requireStripeClient()
} catch (stripeError) {
logger.error('Stripe client not available for webhook processing', {
error: stripeError,
})
return NextResponse.json({ error: 'Stripe client not configured' }, { status: 500 })
}
// Verify webhook signature
let event: Stripe.Event
try {
event = stripe.webhooks.constructEvent(body, signature, env.STRIPE_BILLING_WEBHOOK_SECRET)
} catch (signatureError) {
logger.error('Invalid Stripe webhook signature', {
error: signatureError,
signature,
})
return NextResponse.json({ error: 'Invalid signature' }, { status: 400 })
}
logger.info('Received Stripe invoice webhook', {
eventId: event.id,
eventType: event.type,
})
// Handle specific invoice events
const supportedEvents = [
'invoice.payment_succeeded',
'invoice.payment_failed',
'invoice.finalized',
]
if (supportedEvents.includes(event.type)) {
try {
await handleInvoiceWebhook(event)
logger.info('Successfully processed invoice webhook', {
eventId: event.id,
eventType: event.type,
})
return NextResponse.json({ received: true })
} catch (processingError) {
logger.error('Failed to process invoice webhook', {
eventId: event.id,
eventType: event.type,
error: processingError,
})
// Return 500 to tell Stripe to retry the webhook
return NextResponse.json({ error: 'Failed to process webhook' }, { status: 500 })
}
} else {
// Not a supported invoice event, ignore
logger.info('Ignoring unsupported webhook event', {
eventId: event.id,
eventType: event.type,
supportedEvents,
})
return NextResponse.json({ received: true })
}
} catch (error) {
logger.error('Fatal error in invoice webhook handler', {
error,
url: request.url,
})
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}
/**
* GET endpoint for webhook health checks
*/
export async function GET() {
return NextResponse.json({
status: 'healthy',
webhook: 'stripe-invoices',
events: ['invoice.payment_succeeded', 'invoice.payment_failed', 'invoice.finalized'],
})
}

View File

@@ -45,6 +45,7 @@ export async function GET(request: Request) {
'support',
'admin',
'qa',
'agent',
]
if (reservedSubdomains.includes(subdomain)) {
return NextResponse.json(

View File

@@ -3,6 +3,7 @@ import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { isDev } from '@/lib/environment'
import { getPersonalAndWorkspaceEnv } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
import { buildTraceSpans } from '@/lib/logs/execution/trace-spans/trace-spans'
@@ -12,7 +13,7 @@ import { getEmailDomain } from '@/lib/urls/utils'
import { decryptSecret } from '@/lib/utils'
import { getBlock } from '@/blocks'
import { db } from '@/db'
import { chat, environment as envTable, userStats, workflow } from '@/db/schema'
import { chat, userStats, workflow } from '@/db/schema'
import { Executor } from '@/executor'
import type { BlockLog, ExecutionResult } from '@/executor/types'
import { Serializer } from '@/serializer'
@@ -453,18 +454,21 @@ export async function executeWorkflowForChat(
{} as Record<string, Record<string, any>>
)
// Get user environment variables for this workflow
// Get user environment variables with workspace precedence
let envVars: Record<string, string> = {}
try {
const envResult = await db
.select()
.from(envTable)
.where(eq(envTable.userId, deployment.userId))
const wfWorkspaceRow = await db
.select({ workspaceId: workflow.workspaceId })
.from(workflow)
.where(eq(workflow.id, workflowId))
.limit(1)
if (envResult.length > 0 && envResult[0].variables) {
envVars = envResult[0].variables as Record<string, string>
}
const workspaceId = wfWorkspaceRow[0]?.workspaceId || undefined
const { personalEncrypted, workspaceEncrypted } = await getPersonalAndWorkspaceEnv(
deployment.userId,
workspaceId
)
envVars = { ...personalEncrypted, ...workspaceEncrypted }
} catch (error) {
logger.warn(`[${requestId}] Could not fetch environment variables:`, error)
}
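
Editor's note: because later spreads win, workspace-scoped variables override personal ones with the same key, which is the "workspace precedence" the comment names. A tiny illustration:

// Later spread wins: workspace values shadow personal values per key.
const personalEncrypted = { API_KEY: 'enc:personal', REGION: 'enc:us-east-1' }
const workspaceEncrypted = { API_KEY: 'enc:workspace' }
const envVars = { ...personalEncrypted, ...workspaceEncrypted }
console.assert(envVars.API_KEY === 'enc:workspace')
console.assert(envVars.REGION === 'enc:us-east-1')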

View File

@@ -1,34 +1,12 @@
import { createCipheriv, createHash, createHmac, randomBytes } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { generateApiKey } from '@/lib/utils'
import { db } from '@/db'
import { copilotApiKeys } from '@/db/schema'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const logger = createLogger('CopilotApiKeysGenerate')
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function encryptRandomIv(plaintext: string, keyString: string): string {
const key = deriveKey(keyString)
const iv = randomBytes(16)
const cipher = createCipheriv('aes-256-gcm', key, iv)
let encrypted = cipher.update(plaintext, 'utf8', 'hex')
encrypted += cipher.final('hex')
const authTag = cipher.getAuthTag().toString('hex')
return `${iv.toString('hex')}:${encrypted}:${authTag}`
}
function computeLookup(plaintext: string, keyString: string): string {
// Deterministic, constant-time comparable MAC: HMAC-SHA256(DB_KEY, plaintext)
return createHmac('sha256', Buffer.from(keyString, 'utf8'))
.update(plaintext, 'utf8')
.digest('hex')
}
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
export async function POST(req: NextRequest) {
try {
@@ -37,34 +15,39 @@ export async function POST(req: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
}
const userId = session.user.id
// Generate and prefix the key (strip the generic sim_ prefix from the random part)
const rawKey = generateApiKey().replace(/^sim_/, '')
const plaintextKey = `sk-sim-copilot-${rawKey}`
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/generate`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId }),
})
// Encrypt with random IV for confidentiality
const dbEncrypted = encryptRandomIv(plaintextKey, env.AGENT_API_DB_ENCRYPTION_KEY)
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent generate key error', { status: res.status, error: errorBody })
return NextResponse.json(
{ error: 'Failed to generate copilot API key' },
{ status: res.status || 500 }
)
}
// Compute deterministic lookup value for O(1) search
const lookup = computeLookup(plaintextKey, env.AGENT_API_DB_ENCRYPTION_KEY)
const data = (await res.json().catch(() => null)) as { apiKey?: string } | null
const [inserted] = await db
.insert(copilotApiKeys)
.values({ userId, apiKeyEncrypted: dbEncrypted, apiKeyLookup: lookup })
.returning({ id: copilotApiKeys.id })
if (!data?.apiKey) {
logger.error('Sim Agent generate key returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
return NextResponse.json(
{ success: true, key: { id: inserted.id, apiKey: plaintextKey } },
{ success: true, key: { id: 'new', apiKey: data.apiKey } },
{ status: 201 }
)
} catch (error) {
logger.error('Failed to generate copilot API key', { error })
logger.error('Failed to proxy generate copilot API key', { error })
return NextResponse.json({ error: 'Failed to generate copilot API key' }, { status: 500 })
}
}
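
Editor's note: the generate route now proxies key creation to the Sim Agent instead of encrypting and storing locally. One idiom worth noting in the forwarding calls is the conditional header spread — spreading an empty object adds nothing, so x-api-key is only attached when configured:

// Conditional-spread idiom used when forwarding to the Sim Agent.
function buildAgentHeaders(apiKey?: string): Record<string, string> {
  return {
    'Content-Type': 'application/json',
    ...(apiKey ? { 'x-api-key': apiKey } : {}),
  }
}
console.assert(!('x-api-key' in buildAgentHeaders()))
console.assert(buildAgentHeaders('secret')['x-api-key'] === 'secret')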

View File

@@ -1,32 +1,12 @@
import { createDecipheriv, createHash } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotApiKeys } from '@/db/schema'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const logger = createLogger('CopilotApiKeys')
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function decryptWithKey(encryptedValue: string, keyString: string): string {
const parts = encryptedValue.split(':')
if (parts.length !== 3) {
throw new Error('Invalid encrypted value format')
}
const [ivHex, encryptedHex, authTagHex] = parts
const key = deriveKey(keyString)
const iv = Buffer.from(ivHex, 'hex')
const decipher = createDecipheriv('aes-256-gcm', key, iv)
decipher.setAuthTag(Buffer.from(authTagHex, 'hex'))
let decrypted = decipher.update(encryptedHex, 'hex', 'utf8')
decrypted += decipher.final('utf8')
return decrypted
}
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
export async function GET(request: NextRequest) {
try {
@@ -35,22 +15,31 @@ export async function GET(request: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
}
const userId = session.user.id
const rows = await db
.select({ id: copilotApiKeys.id, apiKeyEncrypted: copilotApiKeys.apiKeyEncrypted })
.from(copilotApiKeys)
.where(eq(copilotApiKeys.userId, userId))
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/get-api-keys`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId }),
})
const keys = rows.map((row) => ({
id: row.id,
apiKey: decryptWithKey(row.apiKeyEncrypted, env.AGENT_API_DB_ENCRYPTION_KEY as string),
}))
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent get-api-keys error', { status: res.status, error: errorBody })
return NextResponse.json({ error: 'Failed to get keys' }, { status: res.status || 500 })
}
const apiKeys = (await res.json().catch(() => null)) as { id: string; apiKey: string }[] | null
if (!Array.isArray(apiKeys)) {
logger.error('Sim Agent get-api-keys returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
const keys = apiKeys
return NextResponse.json({ keys }, { status: 200 })
} catch (error) {
@@ -73,9 +62,26 @@ export async function DELETE(request: NextRequest) {
return NextResponse.json({ error: 'id is required' }, { status: 400 })
}
await db
.delete(copilotApiKeys)
.where(and(eq(copilotApiKeys.userId, userId), eq(copilotApiKeys.id, id)))
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/delete`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify({ userId, apiKeyId: id }),
})
if (!res.ok) {
const errorBody = await res.text().catch(() => '')
logger.error('Sim Agent delete key error', { status: res.status, error: errorBody })
return NextResponse.json({ error: 'Failed to delete key' }, { status: res.status || 500 })
}
const data = (await res.json().catch(() => null)) as { success?: boolean } | null
if (!data?.success) {
logger.error('Sim Agent delete key returned invalid payload')
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
}
return NextResponse.json({ success: true }, { status: 200 })
} catch (error) {

View File

@@ -1,50 +1,29 @@
import { createHmac } from 'crypto'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { env } from '@/lib/env'
import { checkInternalApiKey } from '@/lib/copilot/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotApiKeys, userStats } from '@/db/schema'
import { userStats } from '@/db/schema'
const logger = createLogger('CopilotApiKeysValidate')
function computeLookup(plaintext: string, keyString: string): string {
// Deterministic MAC: HMAC-SHA256(DB_KEY, plaintext)
return createHmac('sha256', Buffer.from(keyString, 'utf8'))
.update(plaintext, 'utf8')
.digest('hex')
}
export async function POST(req: NextRequest) {
try {
if (!env.AGENT_API_DB_ENCRYPTION_KEY) {
logger.error('AGENT_API_DB_ENCRYPTION_KEY is not set')
return NextResponse.json({ error: 'Server not configured' }, { status: 500 })
// Authenticate via internal API key header
const auth = checkInternalApiKey(req)
if (!auth.success) {
return new NextResponse(null, { status: 401 })
}
const body = await req.json().catch(() => null)
const apiKey = typeof body?.apiKey === 'string' ? body.apiKey : undefined
const userId = typeof body?.userId === 'string' ? body.userId : undefined
if (!apiKey) {
return new NextResponse(null, { status: 401 })
if (!userId) {
return NextResponse.json({ error: 'userId is required' }, { status: 400 })
}
const lookup = computeLookup(apiKey, env.AGENT_API_DB_ENCRYPTION_KEY)
logger.info('[API VALIDATION] Validating usage limit', { userId })
// Find matching API key and its user
const rows = await db
.select({ id: copilotApiKeys.id, userId: copilotApiKeys.userId })
.from(copilotApiKeys)
.where(eq(copilotApiKeys.apiKeyLookup, lookup))
.limit(1)
if (rows.length === 0) {
return new NextResponse(null, { status: 401 })
}
const { userId } = rows[0]
// Check usage for the associated user
const usage = await db
.select({
currentPeriodCost: userStats.currentPeriodCost,
@@ -55,6 +34,8 @@ export async function POST(req: NextRequest) {
.where(eq(userStats.userId, userId))
.limit(1)
logger.info('[API VALIDATION] Usage limit validated', { userId, usage })
if (usage.length > 0) {
const currentUsage = Number.parseFloat(
(usage[0].currentPeriodCost?.toString() as string) ||
@@ -64,16 +45,14 @@ export async function POST(req: NextRequest) {
const limit = Number.parseFloat((usage[0].currentUsageLimit as unknown as string) || '0')
if (!Number.isNaN(limit) && limit > 0 && currentUsage >= limit) {
// Usage exceeded
logger.info('[API VALIDATION] Usage exceeded', { userId, currentUsage, limit })
return new NextResponse(null, { status: 402 })
}
}
// Valid and within usage limits
return new NextResponse(null, { status: 200 })
} catch (error) {
logger.error('Error validating copilot API key', { error })
return NextResponse.json({ error: 'Failed to validate key' }, { status: 500 })
logger.error('Error validating usage limit', { error })
return NextResponse.json({ error: 'Failed to validate usage' }, { status: 500 })
}
}
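
Editor's note: the validate route's status contract, read off the handler above: 401 for a bad internal key, 400 for a missing userId, 402 when usage exceeds the limit, 200 otherwise. A hypothetical caller; the route path and the x-api-key header name are assumptions:

// Hypothetical internal caller; path and auth header name are assumed.
async function isWithinUsageLimit(
  baseUrl: string,
  internalKey: string,
  userId: string
): Promise<boolean> {
  const res = await fetch(`${baseUrl}/api/copilot/api-keys/validate`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json', 'x-api-key': internalKey },
    body: JSON.stringify({ userId }),
  })
  if (res.status === 402) return false // usage limit exceeded
  if (!res.ok) throw new Error(`Usage validation failed: ${res.status}`)
  return true
}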

View File

@@ -224,9 +224,9 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
messageId: 'mock-uuid-1234-5678',
depth: 0,
origin: 'http://localhost:3000',
chatId: 'chat-123',
}),
})
)
@@ -288,9 +288,9 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
messageId: 'mock-uuid-1234-5678',
depth: 0,
origin: 'http://localhost:3000',
chatId: 'chat-123',
}),
})
)
@@ -300,7 +300,6 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
const newChat = {
id: 'chat-123',
userId: 'user-123',
@@ -309,8 +308,6 @@ describe('Copilot Chat API Route', () => {
}
mockReturning.mockResolvedValue([newChat])
// Mock sim agent response
;(global.fetch as any).mockResolvedValue({
ok: true,
body: new ReadableStream({
@@ -344,9 +341,9 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'agent',
provider: 'openai',
messageId: 'mock-uuid-1234-5678',
depth: 0,
origin: 'http://localhost:3000',
chatId: 'chat-123',
}),
})
)
@@ -356,11 +353,8 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
// Mock sim agent error
;(global.fetch as any).mockResolvedValue({
ok: false,
status: 500,
@@ -406,11 +400,8 @@ describe('Copilot Chat API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock new chat creation
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
// Mock sim agent response
;(global.fetch as any).mockResolvedValue({
ok: true,
body: new ReadableStream({
@@ -440,9 +431,9 @@ describe('Copilot Chat API Route', () => {
stream: true,
streamToolCalls: true,
mode: 'ask',
provider: 'openai',
messageId: 'mock-uuid-1234-5678',
depth: 0,
origin: 'http://localhost:3000',
chatId: 'chat-123',
}),
})
)

View File

@@ -1,4 +1,3 @@
import { createCipheriv, createDecipheriv, createHash, randomBytes } from 'crypto'
import { and, desc, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
@@ -11,77 +10,36 @@ import {
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { getCopilotModel } from '@/lib/copilot/config'
import { TITLE_GENERATION_SYSTEM_PROMPT, TITLE_GENERATION_USER_PROMPT } from '@/lib/copilot/prompts'
import type { CopilotProviderConfig } from '@/lib/copilot/types'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
import { downloadFile } from '@/lib/uploads'
import { downloadFromS3WithConfig } from '@/lib/uploads/s3/s3-client'
import { S3_COPILOT_CONFIG, USE_S3_STORAGE } from '@/lib/uploads/setup'
import { generateChatTitle } from '@/lib/sim-agent/utils'
import { createFileContent, isSupportedFileType } from '@/lib/uploads/file-utils'
import { S3_COPILOT_CONFIG } from '@/lib/uploads/setup'
import { downloadFile, getStorageProvider } from '@/lib/uploads/storage-client'
import { db } from '@/db'
import { copilotChats } from '@/db/schema'
import { executeProviderRequest } from '@/providers'
import { createAnthropicFileContent, isSupportedFileType } from './file-utils'
const logger = createLogger('CopilotChatAPI')
// Sim Agent API configuration
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
function getRequestOrigin(_req: NextRequest): string {
try {
// Strictly use configured Better Auth URL
return env.BETTER_AUTH_URL || ''
} catch (_) {
return ''
}
}
function deriveKey(keyString: string): Buffer {
return createHash('sha256').update(keyString, 'utf8').digest()
}
function decryptWithKey(encryptedValue: string, keyString: string): string {
const [ivHex, encryptedHex, authTagHex] = encryptedValue.split(':')
if (!ivHex || !encryptedHex || !authTagHex) {
throw new Error('Invalid encrypted format')
}
const key = deriveKey(keyString)
const iv = Buffer.from(ivHex, 'hex')
const decipher = createDecipheriv('aes-256-gcm', key, iv)
decipher.setAuthTag(Buffer.from(authTagHex, 'hex'))
let decrypted = decipher.update(encryptedHex, 'hex', 'utf8')
decrypted += decipher.final('utf8')
return decrypted
}
function encryptWithKey(plaintext: string, keyString: string): string {
const key = deriveKey(keyString)
const iv = randomBytes(16)
const cipher = createCipheriv('aes-256-gcm', key, iv)
let encrypted = cipher.update(plaintext, 'utf8', 'hex')
encrypted += cipher.final('hex')
const authTag = cipher.getAuthTag().toString('hex')
return `${iv.toString('hex')}:${encrypted}:${authTag}`
}
// Schema for file attachments
const FileAttachmentSchema = z.object({
id: z.string(),
s3_key: z.string(),
key: z.string(),
filename: z.string(),
media_type: z.string(),
size: z.number(),
})
// Schema for chat messages
const ChatMessageSchema = z.object({
message: z.string().min(1, 'Message is required'),
userMessageId: z.string().optional(), // ID from frontend for the user message
chatId: z.string().optional(),
workflowId: z.string().min(1, 'Workflow ID is required'),
mode: z.enum(['ask', 'agent']).optional().default('agent'),
depth: z.number().int().min(-2).max(3).optional().default(0),
depth: z.number().int().min(0).max(3).optional().default(0),
prefetch: z.boolean().optional(),
createNewChat: z.boolean().optional().default(false),
stream: z.boolean().optional().default(true),
@@ -89,90 +47,32 @@ const ChatMessageSchema = z.object({
fileAttachments: z.array(FileAttachmentSchema).optional(),
provider: z.string().optional().default('openai'),
conversationId: z.string().optional(),
})
/**
* Generate a chat title using LLM
*/
async function generateChatTitle(userMessage: string): Promise<string> {
try {
const { provider, model } = getCopilotModel('title')
// Get the appropriate API key for the provider
let apiKey: string | undefined
if (provider === 'anthropic') {
// Use rotating API key for Anthropic
const { getRotatingApiKey } = require('@/lib/utils')
try {
apiKey = getRotatingApiKey('anthropic')
logger.debug(`Using rotating API key for Anthropic title generation`)
} catch (e) {
// If rotation fails, let the provider handle it
logger.warn(`Failed to get rotating API key for Anthropic:`, e)
}
}
const response = await executeProviderRequest(provider, {
model,
systemPrompt: TITLE_GENERATION_SYSTEM_PROMPT,
context: TITLE_GENERATION_USER_PROMPT(userMessage),
temperature: 0.3,
maxTokens: 50,
apiKey: apiKey || '',
stream: false,
})
if (typeof response === 'object' && 'content' in response) {
return response.content?.trim() || 'New Chat'
}
return 'New Chat'
} catch (error) {
logger.error('Failed to generate chat title:', error)
return 'New Chat'
}
}
/**
* Generate chat title asynchronously and update the database
*/
async function generateChatTitleAsync(
chatId: string,
userMessage: string,
requestId: string,
streamController?: ReadableStreamDefaultController<Uint8Array>
): Promise<void> {
try {
// logger.info(`[${requestId}] Starting async title generation for chat ${chatId}`)
const title = await generateChatTitle(userMessage)
// Update the chat with the generated title
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
contexts: z
.array(
z.object({
kind: z.enum([
'past_chat',
'workflow',
'current_workflow',
'blocks',
'logs',
'workflow_block',
'knowledge',
'templates',
'docs',
]),
label: z.string(),
chatId: z.string().optional(),
workflowId: z.string().optional(),
knowledgeId: z.string().optional(),
blockId: z.string().optional(),
templateId: z.string().optional(),
executionId: z.string().optional(),
// For workflow_block, provide both workflowId and blockId
})
.where(eq(copilotChats.id, chatId))
// Send title_updated event to client if streaming
if (streamController) {
const encoder = new TextEncoder()
const titleEvent = `data: ${JSON.stringify({
type: 'title_updated',
title: title,
})}\n\n`
streamController.enqueue(encoder.encode(titleEvent))
logger.debug(`[${requestId}] Sent title_updated event to client: "${title}"`)
}
// logger.info(`[${requestId}] Generated title for chat ${chatId}: "${title}"`)
} catch (error) {
logger.error(`[${requestId}] Failed to generate title for chat ${chatId}:`, error)
// Don't throw - this is a background operation
}
}
)
.optional(),
})
/**
* POST /api/copilot/chat
@@ -206,14 +106,45 @@ export async function POST(req: NextRequest) {
fileAttachments,
provider,
conversationId,
contexts,
} = ChatMessageSchema.parse(body)
// Derive request origin for downstream service
const requestOrigin = getRequestOrigin(req)
if (!requestOrigin) {
logger.error(`[${tracker.requestId}] Missing required configuration: BETTER_AUTH_URL`)
return createInternalServerErrorResponse('Missing required configuration: BETTER_AUTH_URL')
// Ensure we have a consistent user message ID for this request
const userMessageIdToUse = userMessageId || crypto.randomUUID()
try {
logger.info(`[${tracker.requestId}] Received chat POST`, {
hasContexts: Array.isArray(contexts),
contextsCount: Array.isArray(contexts) ? contexts.length : 0,
contextsPreview: Array.isArray(contexts)
? contexts.map((c: any) => ({
kind: c?.kind,
chatId: c?.chatId,
workflowId: c?.workflowId,
executionId: (c as any)?.executionId,
label: c?.label,
}))
: undefined,
})
} catch {}
// Preprocess contexts server-side
let agentContexts: Array<{ type: string; content: string }> = []
if (Array.isArray(contexts) && contexts.length > 0) {
try {
const { processContextsServer } = await import('@/lib/copilot/process-contents')
const processed = await processContextsServer(contexts as any, authenticatedUserId, message)
agentContexts = processed
logger.info(`[${tracker.requestId}] Contexts processed for request`, {
processedCount: agentContexts.length,
kinds: agentContexts.map((c) => c.type),
lengthPreview: agentContexts.map((c) => c.content?.length ?? 0),
})
if (Array.isArray(contexts) && contexts.length > 0 && agentContexts.length === 0) {
logger.warn(
`[${tracker.requestId}] Contexts provided but none processed. Check executionId for logs contexts.`
)
}
} catch (e) {
logger.error(`[${tracker.requestId}] Failed to process contexts`, e)
}
}
// Consolidation mapping: map negative depths to base depth with prefetch=true
@@ -229,22 +160,6 @@ export async function POST(req: NextRequest) {
}
}
// logger.info(`[${tracker.requestId}] Processing copilot chat request`, {
// userId: authenticatedUserId,
// workflowId,
// chatId,
// mode,
// stream,
// createNewChat,
// messageLength: message.length,
// hasImplicitFeedback: !!implicitFeedback,
// provider: provider || 'openai',
// hasConversationId: !!conversationId,
// depth,
// prefetch,
// origin: requestOrigin,
// })
// Handle chat context
let currentChat: any = null
let conversationHistory: any[] = []
@@ -285,8 +200,6 @@ export async function POST(req: NextRequest) {
// Process file attachments if present
const processedFileContents: any[] = []
if (fileAttachments && fileAttachments.length > 0) {
// logger.info(`[${tracker.requestId}] Processing ${fileAttachments.length} file attachments`)
for (const attachment of fileAttachments) {
try {
// Check if file type is supported
@@ -295,23 +208,30 @@ export async function POST(req: NextRequest) {
continue
}
// Download file from S3
// logger.info(`[${tracker.requestId}] Downloading file: ${attachment.s3_key}`)
const storageProvider = getStorageProvider()
let fileBuffer: Buffer
if (USE_S3_STORAGE) {
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
if (storageProvider === 's3') {
fileBuffer = await downloadFile(attachment.key, {
bucket: S3_COPILOT_CONFIG.bucket,
region: S3_COPILOT_CONFIG.region,
})
} else if (storageProvider === 'blob') {
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
fileBuffer = await downloadFile(attachment.key, {
containerName: BLOB_COPILOT_CONFIG.containerName,
accountName: BLOB_COPILOT_CONFIG.accountName,
accountKey: BLOB_COPILOT_CONFIG.accountKey,
connectionString: BLOB_COPILOT_CONFIG.connectionString,
})
} else {
// Fallback to generic downloadFile for other storage providers
fileBuffer = await downloadFile(attachment.s3_key)
fileBuffer = await downloadFile(attachment.key)
}
// Convert to Anthropic format
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
// Convert to the generic file content format
const fileContent = createFileContent(fileBuffer, attachment.media_type)
if (fileContent) {
processedFileContents.push(fileContent)
// logger.info(
// `[${tracker.requestId}] Processed file: ${attachment.filename} (${attachment.media_type})`
// )
}
} catch (error) {
logger.error(
@@ -336,14 +256,26 @@ export async function POST(req: NextRequest) {
for (const attachment of msg.fileAttachments) {
try {
if (isSupportedFileType(attachment.media_type)) {
const storageProvider = getStorageProvider()
let fileBuffer: Buffer
if (USE_S3_STORAGE) {
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
if (storageProvider === 's3') {
fileBuffer = await downloadFile(attachment.key, {
bucket: S3_COPILOT_CONFIG.bucket,
region: S3_COPILOT_CONFIG.region,
})
} else if (storageProvider === 'blob') {
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
fileBuffer = await downloadFile(attachment.key, {
containerName: BLOB_COPILOT_CONFIG.containerName,
accountName: BLOB_COPILOT_CONFIG.accountName,
accountKey: BLOB_COPILOT_CONFIG.accountKey,
connectionString: BLOB_COPILOT_CONFIG.connectionString,
})
} else {
// Fallback to generic downloadFile for other storage providers
fileBuffer = await downloadFile(attachment.s3_key)
fileBuffer = await downloadFile(attachment.key)
}
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
const fileContent = createFileContent(fileBuffer, attachment.media_type)
if (fileContent) {
content.push(fileContent)
}
@@ -399,8 +331,31 @@ export async function POST(req: NextRequest) {
})
}
const defaults = getCopilotModel('chat')
const modelToUse = env.COPILOT_MODEL || defaults.model
let providerConfig: CopilotProviderConfig | undefined
const providerEnv = env.COPILOT_PROVIDER as any
if (providerEnv) {
if (providerEnv === 'azure-openai') {
providerConfig = {
provider: 'azure-openai',
model: modelToUse,
apiKey: env.AZURE_OPENAI_API_KEY,
apiVersion: 'preview',
endpoint: env.AZURE_OPENAI_ENDPOINT,
}
} else {
providerConfig = {
provider: providerEnv,
model: modelToUse,
apiKey: env.COPILOT_API_KEY,
}
}
}
// Determine provider and conversationId to use for this request
const providerToUse = provider || 'openai'
const effectiveConversationId =
(currentChat?.conversationId as string | undefined) || conversationId
@@ -416,15 +371,21 @@ export async function POST(req: NextRequest) {
stream: stream,
streamToolCalls: true,
mode: mode,
provider: providerToUse,
messageId: userMessageIdToUse,
...(providerConfig ? { provider: providerConfig } : {}),
...(effectiveConversationId ? { conversationId: effectiveConversationId } : {}),
...(typeof effectiveDepth === 'number' ? { depth: effectiveDepth } : {}),
...(typeof effectivePrefetch === 'boolean' ? { prefetch: effectivePrefetch } : {}),
...(session?.user?.name && { userName: session.user.name }),
...(requestOrigin ? { origin: requestOrigin } : {}),
...(agentContexts.length > 0 && { context: agentContexts }),
...(actualChatId ? { chatId: actualChatId } : {}),
}
// Log the payload being sent to the streaming endpoint
try {
logger.info(`[${tracker.requestId}] About to call Sim Agent with context`, {
context: (requestPayload as any).context,
})
} catch {}
const simAgentResponse = await fetch(`${SIM_AGENT_API_URL}/api/chat-completion-streaming`, {
method: 'POST',
@@ -455,15 +416,18 @@ export async function POST(req: NextRequest) {
// If streaming is requested, forward the stream and update chat later
if (stream && simAgentResponse.body) {
// logger.info(`[${tracker.requestId}] Streaming response from sim agent`)
// Create user message to save
const userMessage = {
id: userMessageId || crypto.randomUUID(), // Use frontend ID if provided
id: userMessageIdToUse, // Consistent ID used for request and persistence
role: 'user',
content: message,
timestamp: new Date().toISOString(),
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
...(Array.isArray(contexts) &&
contexts.length > 0 && {
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
}),
}
// Create a pass-through stream that captures the response
@@ -495,30 +459,30 @@ export async function POST(req: NextRequest) {
// Start title generation in parallel if needed
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
// logger.info(`[${tracker.requestId}] Starting title generation with stream updates`, {
// chatId: actualChatId,
// hasTitle: !!currentChat?.title,
// conversationLength: conversationHistory.length,
// message: message.substring(0, 100) + (message.length > 100 ? '...' : ''),
// })
generateChatTitleAsync(actualChatId, message, tracker.requestId, controller).catch(
(error) => {
generateChatTitle(message)
.then(async (title) => {
if (title) {
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
})
.where(eq(copilotChats.id, actualChatId!))
const titleEvent = `data: ${JSON.stringify({
type: 'title_updated',
title: title,
})}\n\n`
controller.enqueue(encoder.encode(titleEvent))
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
}
})
.catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
}
)
})
} else {
// logger.debug(`[${tracker.requestId}] Skipping title generation`, {
// chatId: actualChatId,
// hasTitle: !!currentChat?.title,
// conversationLength: conversationHistory.length,
// reason: !actualChatId
// ? 'no chatId'
// : currentChat?.title
// ? 'already has title'
// : conversationHistory.length > 0
// ? 'not first message'
// : 'unknown',
// })
logger.debug(`[${tracker.requestId}] Skipping title generation`)
}
// Forward the sim agent stream and capture assistant response
@@ -529,24 +493,9 @@ export async function POST(req: NextRequest) {
while (true) {
const { done, value } = await reader.read()
if (done) {
// logger.info(`[${tracker.requestId}] Stream reading completed`)
break
}
// Check if client disconnected before processing chunk
try {
// Forward the chunk to client immediately
controller.enqueue(value)
} catch (error) {
// Client disconnected - stop reading from sim agent
// logger.info(
// `[${tracker.requestId}] Client disconnected, stopping stream processing`
// )
reader.cancel() // Stop reading from sim agent
break
}
const chunkSize = value.byteLength
// Decode and parse SSE events for logging and capturing content
const decodedChunk = decoder.decode(value, { stream: true })
buffer += decodedChunk
@@ -581,22 +530,12 @@ export async function POST(req: NextRequest) {
break
case 'reasoning':
// Treat like thinking: do not add to assistantContent to avoid leaking
logger.debug(
`[${tracker.requestId}] Reasoning chunk received (${(event.data || event.content || '').length} chars)`
)
break
case 'tool_call':
// logger.info(
// `[${tracker.requestId}] Tool call ${event.data?.partial ? '(partial)' : '(complete)'}:`,
// {
// id: event.data?.id,
// name: event.data?.name,
// arguments: event.data?.arguments,
// blockIndex: event.data?._blockIndex,
// }
// )
if (!event.data?.partial) {
toolCalls.push(event.data)
if (event.data?.id) {
@@ -606,23 +545,12 @@ export async function POST(req: NextRequest) {
break
case 'tool_generating':
// logger.info(`[${tracker.requestId}] Tool generating:`, {
// toolCallId: event.toolCallId,
// toolName: event.toolName,
// })
if (event.toolCallId) {
startedToolExecutionIds.add(event.toolCallId)
}
break
case 'tool_result':
// logger.info(`[${tracker.requestId}] Tool result received:`, {
// toolCallId: event.toolCallId,
// toolName: event.toolName,
// success: event.success,
// result: `${JSON.stringify(event.result).substring(0, 200)}...`,
// resultSize: JSON.stringify(event.result).length,
// })
if (event.toolCallId) {
completedToolExecutionIds.add(event.toolCallId)
}
@@ -667,6 +595,47 @@ export async function POST(req: NextRequest) {
default:
}
// Emit to client: rewrite 'error' events into user-friendly assistant message
if (event?.type === 'error') {
try {
const displayMessage: string =
(event?.data && (event.data.displayMessage as string)) ||
'Sorry, I encountered an error. Please try again.'
const formatted = `_${displayMessage}_`
// Accumulate so it persists to DB as assistant content
assistantContent += formatted
// Send as content chunk
try {
controller.enqueue(
encoder.encode(
`data: ${JSON.stringify({ type: 'content', data: formatted })}\n\n`
)
)
} catch (enqueueErr) {
reader.cancel()
break
}
// Then close this response cleanly for the client
try {
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ type: 'done' })}\n\n`)
)
} catch (enqueueErr) {
reader.cancel()
break
}
} catch {}
// Do not forward the original error event
} else {
// Forward original event to client
try {
controller.enqueue(encoder.encode(`data: ${jsonStr}\n\n`))
} catch (enqueueErr) {
reader.cancel()
break
}
}
} catch (e) {
// Enhanced error handling for large payloads and parsing issues
const lineLength = line.length
@@ -699,10 +668,37 @@ export async function POST(req: NextRequest) {
logger.debug(`[${tracker.requestId}] Processing remaining buffer: "${buffer}"`)
if (buffer.startsWith('data: ')) {
try {
const event = JSON.parse(buffer.slice(6))
const jsonStr = buffer.slice(6)
const event = JSON.parse(jsonStr)
if (event.type === 'content' && event.data) {
assistantContent += event.data
}
// Forward remaining event, applying same error rewrite behavior
if (event?.type === 'error') {
const displayMessage: string =
(event?.data && (event.data.displayMessage as string)) ||
'Sorry, I encountered an error. Please try again.'
const formatted = `_${displayMessage}_`
assistantContent += formatted
try {
controller.enqueue(
encoder.encode(
`data: ${JSON.stringify({ type: 'content', data: formatted })}\n\n`
)
)
controller.enqueue(
encoder.encode(`data: ${JSON.stringify({ type: 'done' })}\n\n`)
)
} catch (enqueueErr) {
reader.cancel()
}
} else {
try {
controller.enqueue(encoder.encode(`data: ${jsonStr}\n\n`))
} catch (enqueueErr) {
reader.cancel()
}
}
} catch (e) {
logger.warn(`[${tracker.requestId}] Failed to parse final buffer: "${buffer}"`)
}
@@ -818,11 +814,16 @@ export async function POST(req: NextRequest) {
// Save messages if we have a chat
if (currentChat && responseData.content) {
const userMessage = {
id: userMessageId || crypto.randomUUID(), // Use frontend ID if provided
id: userMessageIdToUse, // Consistent ID used for request and persistence
role: 'user',
content: message,
timestamp: new Date().toISOString(),
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
...(Array.isArray(contexts) &&
contexts.length > 0 && {
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
}),
}
const assistantMessage = {
@@ -837,9 +838,22 @@ export async function POST(req: NextRequest) {
// Start title generation in parallel if this is first message (non-streaming)
if (actualChatId && !currentChat.title && conversationHistory.length === 0) {
logger.info(`[${tracker.requestId}] Starting title generation for non-streaming response`)
generateChatTitleAsync(actualChatId, message, tracker.requestId).catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
})
generateChatTitle(message)
.then(async (title) => {
if (title) {
await db
.update(copilotChats)
.set({
title,
updatedAt: new Date(),
})
.where(eq(copilotChats.id, actualChatId!))
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
}
})
.catch((error) => {
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
})
}
// Update chat in database immediately (without blocking for title)
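
Editor's note: on the wire, the error-rewrite logic above means a client never receives a raw 'error' event; it sees an italicized content chunk followed by a clean 'done'. Illustrative SSE payloads (the event shapes follow the handler; the message text is an example):

// What the Sim Agent might emit:
const upstreamEvent = `data: ${JSON.stringify({
  type: 'error',
  data: { displayMessage: 'The model timed out.' },
})}\n\n`

// What the client receives instead:
const rewritten =
  `data: ${JSON.stringify({ type: 'content', data: '_The model timed out._' })}\n\n` +
  `data: ${JSON.stringify({ type: 'done' })}\n\n`

console.log(upstreamEvent, rewritten)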

View File

@@ -229,7 +229,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists - override the default empty array
const existingChat = {
id: 'chat-123',
userId: 'user-123',
@@ -267,7 +266,6 @@ describe('Copilot Chat Update Messages API Route', () => {
messageCount: 2,
})
// Verify database operations
expect(mockSelect).toHaveBeenCalled()
expect(mockUpdate).toHaveBeenCalled()
expect(mockSet).toHaveBeenCalledWith({
@@ -280,7 +278,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-456',
userId: 'user-123',
@@ -341,7 +338,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-789',
userId: 'user-123',
@@ -374,7 +370,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock database error during chat lookup
mockLimit.mockRejectedValueOnce(new Error('Database connection failed'))
const req = createMockRequest('POST', {
@@ -401,7 +396,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-123',
userId: 'user-123',
@@ -409,7 +403,6 @@ describe('Copilot Chat Update Messages API Route', () => {
}
mockLimit.mockResolvedValueOnce([existingChat])
// Mock database error during update
mockSet.mockReturnValueOnce({
where: vi.fn().mockRejectedValue(new Error('Update operation failed')),
})
@@ -438,7 +431,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Create a request with invalid JSON
const req = new NextRequest('http://localhost:3000/api/copilot/chat/update-messages', {
method: 'POST',
body: '{invalid-json',
@@ -459,7 +451,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-large',
userId: 'user-123',
@@ -467,7 +458,6 @@ describe('Copilot Chat Update Messages API Route', () => {
}
mockLimit.mockResolvedValueOnce([existingChat])
// Create a large array of messages
const messages = Array.from({ length: 100 }, (_, i) => ({
id: `msg-${i + 1}`,
role: i % 2 === 0 ? 'user' : 'assistant',
@@ -500,7 +490,6 @@ describe('Copilot Chat Update Messages API Route', () => {
const authMocks = mockAuth()
authMocks.setAuthenticated()
// Mock chat exists
const existingChat = {
id: 'chat-mixed',
userId: 'user-123',

View File

@@ -28,7 +28,7 @@ const UpdateMessagesSchema = z.object({
.array(
z.object({
id: z.string(),
s3_key: z.string(),
key: z.string(),
filename: z.string(),
media_type: z.string(),
size: z.number(),

View File

@@ -0,0 +1,39 @@
import { desc, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import {
authenticateCopilotRequestSessionOnly,
createInternalServerErrorResponse,
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { copilotChats } from '@/db/schema'
const logger = createLogger('CopilotChatsListAPI')
export async function GET(_req: NextRequest) {
try {
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
if (!isAuthenticated || !userId) {
return createUnauthorizedResponse()
}
const chats = await db
.select({
id: copilotChats.id,
title: copilotChats.title,
workflowId: copilotChats.workflowId,
updatedAt: copilotChats.updatedAt,
})
.from(copilotChats)
.where(eq(copilotChats.userId, userId))
.orderBy(desc(copilotChats.updatedAt))
logger.info(`Retrieved ${chats.length} chats for user ${userId}`)
return NextResponse.json({ success: true, chats })
} catch (error) {
logger.error('Error fetching user copilot chats:', error)
return createInternalServerErrorResponse('Failed to fetch user chats')
}
}
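
For reference, a minimal client-side sketch of consuming this new list endpoint. Only the handler appears in this diff, so the route path here is an assumption; the response shape follows the handler above.

// Hypothetical usage; the route path is an assumption based on this handler
async function fetchCopilotChats() {
  const res = await fetch('/api/copilot/chats')
  const body = await res.json()
  if (!res.ok || !body.success) throw new Error('Failed to load copilot chats')
  // Chats arrive newest-first, per orderBy(desc(copilotChats.updatedAt)) above
  return body.chats as Array<{ id: string; title: string | null; workflowId: string; updatedAt: string }>
}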

View File

@@ -0,0 +1,68 @@
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import {
authenticateCopilotRequestSessionOnly,
createBadRequestResponse,
createInternalServerErrorResponse,
createRequestTracker,
createUnauthorizedResponse,
} from '@/lib/copilot/auth'
import { env } from '@/lib/env'
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
const BodySchema = z.object({
messageId: z.string(),
diffCreated: z.boolean(),
diffAccepted: z.boolean(),
})
export async function POST(req: NextRequest) {
const tracker = createRequestTracker()
try {
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
if (!isAuthenticated || !userId) {
return createUnauthorizedResponse()
}
const json = await req.json().catch(() => ({}))
const parsed = BodySchema.safeParse(json)
if (!parsed.success) {
return createBadRequestResponse('Invalid request body for copilot stats')
}
const { messageId, diffCreated, diffAccepted } = parsed.data as any
// Build outgoing payload for Sim Agent with only required fields
const payload: Record<string, any> = {
messageId,
diffCreated,
diffAccepted,
}
const agentRes = await fetch(`${SIM_AGENT_API_URL}/api/stats`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
},
body: JSON.stringify(payload),
})
// Prefer not to block clients; still relay status
let agentJson: any = null
try {
agentJson = await agentRes.json()
} catch {}
if (!agentRes.ok) {
const message = (agentJson && (agentJson.error || agentJson.message)) || 'Upstream error'
return NextResponse.json({ success: false, error: message }, { status: 400 })
}
return NextResponse.json({ success: true })
} catch (error) {
return createInternalServerErrorResponse('Failed to forward copilot stats')
}
}
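
Since the handler validates only messageId, diffCreated, and diffAccepted, callers can treat this as fire-and-forget telemetry. A hypothetical client sketch; the route path is an assumption:

// Best-effort stats reporting; errors are swallowed so the UI never blocks on telemetry
async function reportCopilotDiffStats(messageId: string, diffCreated: boolean, diffAccepted: boolean) {
  try {
    await fetch('/api/copilot/stats', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({ messageId, diffCreated, diffAccepted }),
    })
  } catch {
    // Ignore network failures; stats are best-effort
  }
}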

View File

@@ -10,7 +10,6 @@ import type { EnvironmentVariable } from '@/stores/settings/environment/types'
const logger = createLogger('EnvironmentAPI')
// Schema for environment variable updates
const EnvVarSchema = z.object({
variables: z.record(z.string()),
})
@@ -30,17 +29,13 @@ export async function POST(req: NextRequest) {
try {
const { variables } = EnvVarSchema.parse(body)
-      // Encrypt all variables
-      const encryptedVariables = await Object.entries(variables).reduce(
-        async (accPromise, [key, value]) => {
-          const acc = await accPromise
+      const encryptedVariables = await Promise.all(
+        Object.entries(variables).map(async ([key, value]) => {
           const { encrypted } = await encryptSecret(value)
-          return { ...acc, [key]: encrypted }
-        },
-        Promise.resolve({})
-      )
+          return [key, encrypted] as const
+        })
+      ).then((entries) => Object.fromEntries(entries))
-      // Replace all environment variables for user
await db
.insert(environment)
.values({
@@ -80,7 +75,6 @@ export async function GET(request: Request) {
const requestId = crypto.randomUUID().slice(0, 8)
try {
// Get the session directly in the API route
const session = await getSession()
if (!session?.user?.id) {
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
@@ -99,18 +93,15 @@ export async function GET(request: Request) {
return NextResponse.json({ data: {} }, { status: 200 })
}
// Decrypt the variables for client-side use
const encryptedVariables = result[0].variables as Record<string, string>
const decryptedVariables: Record<string, EnvironmentVariable> = {}
// Decrypt each variable
for (const [key, encryptedValue] of Object.entries(encryptedVariables)) {
try {
const { decrypted } = await decryptSecret(encryptedValue)
decryptedVariables[key] = { key, value: decrypted }
} catch (error) {
logger.error(`[${requestId}] Error decrypting variable ${key}`, error)
// If decryption fails, provide a placeholder
decryptedVariables[key] = { key, value: '' }
}
}
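
The POST handler's change above is worth spelling out: the old reduce chain awaited the accumulator before each encryptSecret call, so variables were encrypted strictly one after another; the new version starts every encryption at once and folds the results afterwards. A side-by-side sketch of the two patterns, assuming variables: Record<string, string> and the encryptSecret helper from the route above:

// Before: awaiting the accumulator serializes every encryption
const sequential = await Object.entries(variables).reduce(
  async (accPromise, [key, value]) => {
    const acc = await accPromise // blocks until all earlier encryptions finish
    const { encrypted } = await encryptSecret(value)
    return { ...acc, [key]: encrypted }
  },
  Promise.resolve({} as Record<string, string>)
)

// After: all encryptions run concurrently, then the entries fold into one object
const parallel = Object.fromEntries(
  await Promise.all(
    Object.entries(variables).map(async ([key, value]) => {
      const { encrypted } = await encryptSecret(value)
      return [key, encrypted] as const
    })
  )
)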

View File

@@ -1,223 +0,0 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getEnvironmentVariableKeys } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { decryptSecret, encryptSecret } from '@/lib/utils'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { db } from '@/db'
import { environment } from '@/db/schema'
const logger = createLogger('EnvironmentVariablesAPI')
// Schema for environment variable updates
const EnvVarSchema = z.object({
variables: z.record(z.string()),
})
export async function GET(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
try {
// For GET requests, check for workflowId in query params
const { searchParams } = new URL(request.url)
const workflowId = searchParams.get('workflowId')
// Use dual authentication pattern like other copilot tools
const userId = await getUserId(requestId, workflowId || undefined)
if (!userId) {
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Get only the variable names (keys), not values
const result = await getEnvironmentVariableKeys(userId)
return NextResponse.json(
{
success: true,
output: result,
},
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Environment variables fetch error`, error)
return NextResponse.json(
{
success: false,
error: error.message || 'Failed to get environment variables',
},
{ status: 500 }
)
}
}
export async function PUT(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
try {
const body = await request.json()
const { workflowId, variables } = body
// Use dual authentication pattern like other copilot tools
const userId = await getUserId(requestId, workflowId)
if (!userId) {
logger.warn(`[${requestId}] Unauthorized environment variables set attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
try {
const { variables: validatedVariables } = EnvVarSchema.parse({ variables })
// Get existing environment variables for this user
const existingData = await db
.select()
.from(environment)
.where(eq(environment.userId, userId))
.limit(1)
// Start with existing encrypted variables or empty object
const existingEncryptedVariables =
(existingData[0]?.variables as Record<string, string>) || {}
// Determine which variables are new or changed by comparing with decrypted existing values
const variablesToEncrypt: Record<string, string> = {}
const addedVariables: string[] = []
const updatedVariables: string[] = []
for (const [key, newValue] of Object.entries(validatedVariables)) {
if (!(key in existingEncryptedVariables)) {
// New variable
variablesToEncrypt[key] = newValue
addedVariables.push(key)
} else {
// Check if the value has actually changed by decrypting the existing value
try {
const { decrypted: existingValue } = await decryptSecret(
existingEncryptedVariables[key]
)
if (existingValue !== newValue) {
// Value changed, needs re-encryption
variablesToEncrypt[key] = newValue
updatedVariables.push(key)
}
// If values are the same, keep the existing encrypted value
} catch (decryptError) {
// If we can't decrypt the existing value, treat as changed and re-encrypt
logger.warn(
`[${requestId}] Could not decrypt existing variable ${key}, re-encrypting`,
{ error: decryptError }
)
variablesToEncrypt[key] = newValue
updatedVariables.push(key)
}
}
}
// Only encrypt the variables that are new or changed
const newlyEncryptedVariables = await Object.entries(variablesToEncrypt).reduce(
async (accPromise, [key, value]) => {
const acc = await accPromise
const { encrypted } = await encryptSecret(value)
return { ...acc, [key]: encrypted }
},
Promise.resolve({})
)
// Merge existing encrypted variables with newly encrypted ones
const finalEncryptedVariables = { ...existingEncryptedVariables, ...newlyEncryptedVariables }
// Update or insert environment variables for user
await db
.insert(environment)
.values({
id: crypto.randomUUID(),
userId: userId,
variables: finalEncryptedVariables,
updatedAt: new Date(),
})
.onConflictDoUpdate({
target: [environment.userId],
set: {
variables: finalEncryptedVariables,
updatedAt: new Date(),
},
})
return NextResponse.json(
{
success: true,
output: {
message: `Successfully processed ${Object.keys(validatedVariables).length} environment variable(s): ${addedVariables.length} added, ${updatedVariables.length} updated`,
variableCount: Object.keys(validatedVariables).length,
variableNames: Object.keys(validatedVariables),
totalVariableCount: Object.keys(finalEncryptedVariables).length,
addedVariables,
updatedVariables,
},
},
{ status: 200 }
)
} catch (validationError) {
if (validationError instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid environment variables data`, {
errors: validationError.errors,
})
return NextResponse.json(
{ error: 'Invalid request data', details: validationError.errors },
{ status: 400 }
)
}
throw validationError
}
} catch (error: any) {
logger.error(`[${requestId}] Environment variables set error`, error)
return NextResponse.json(
{
success: false,
error: error.message || 'Failed to set environment variables',
},
{ status: 500 }
)
}
}
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
try {
const body = await request.json()
const { workflowId } = body
// Use dual authentication pattern like other copilot tools
const userId = await getUserId(requestId, workflowId)
if (!userId) {
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Get only the variable names (keys), not values
const result = await getEnvironmentVariableKeys(userId)
return NextResponse.json(
{
success: true,
output: result,
},
{ status: 200 }
)
} catch (error: any) {
logger.error(`[${requestId}] Environment variables fetch error`, error)
return NextResponse.json(
{
success: false,
error: error.message || 'Failed to get environment variables',
},
{ status: 500 }
)
}
}

View File

@@ -1,16 +1,8 @@
-import {
-  AbortMultipartUploadCommand,
-  CompleteMultipartUploadCommand,
-  CreateMultipartUploadCommand,
-  UploadPartCommand,
-} from '@aws-sdk/client-s3'
-import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
 import { type NextRequest, NextResponse } from 'next/server'
 import { v4 as uuidv4 } from 'uuid'
 import { getSession } from '@/lib/auth'
 import { createLogger } from '@/lib/logs/console/logger'
 import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
-import { S3_KB_CONFIG } from '@/lib/uploads/setup'
+import { BLOB_KB_CONFIG } from '@/lib/uploads/setup'
const logger = createLogger('MultipartUploadAPI')
@@ -26,15 +18,6 @@ interface GetPartUrlsRequest {
partNumbers: number[]
}
-interface CompleteMultipartRequest {
-  uploadId: string
-  key: string
-  parts: Array<{
-    ETag: string
-    PartNumber: number
-  }>
-}
export async function POST(request: NextRequest) {
try {
const session = await getSession()
@@ -44,106 +27,214 @@ export async function POST(request: NextRequest) {
const action = request.nextUrl.searchParams.get('action')
-    if (!isUsingCloudStorage() || getStorageProvider() !== 's3') {
+    if (!isUsingCloudStorage()) {
       return NextResponse.json(
-        { error: 'Multipart upload is only available with S3 storage' },
+        { error: 'Multipart upload is only available with cloud storage (S3 or Azure Blob)' },
         { status: 400 }
       )
     }
-    const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
-    const s3Client = getS3Client()
+    const storageProvider = getStorageProvider()
switch (action) {
case 'initiate': {
const data: InitiateMultipartRequest = await request.json()
-        const { fileName, contentType } = data
+        const { fileName, contentType, fileSize } = data
-        const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
-        const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
+        if (storageProvider === 's3') {
+          const { initiateS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
-        const command = new CreateMultipartUploadCommand({
-          Bucket: S3_KB_CONFIG.bucket,
-          Key: uniqueKey,
-          ContentType: contentType,
-          Metadata: {
-            originalName: fileName,
-            uploadedAt: new Date().toISOString(),
-            purpose: 'knowledge-base',
-          },
-        })
+          const result = await initiateS3MultipartUpload({
+            fileName,
+            contentType,
+            fileSize,
+          })
-        const response = await s3Client.send(command)
+          logger.info(`Initiated S3 multipart upload for ${fileName}: ${result.uploadId}`)
-        logger.info(`Initiated multipart upload for ${fileName}: ${response.UploadId}`)
+          return NextResponse.json({
+            uploadId: result.uploadId,
+            key: result.key,
+          })
+        }
+        if (storageProvider === 'blob') {
+          const { initiateMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
-        return NextResponse.json({
-          uploadId: response.UploadId,
-          key: uniqueKey,
-        })
+          const result = await initiateMultipartUpload({
+            fileName,
+            contentType,
+            fileSize,
+            customConfig: {
+              containerName: BLOB_KB_CONFIG.containerName,
+              accountName: BLOB_KB_CONFIG.accountName,
+              accountKey: BLOB_KB_CONFIG.accountKey,
+              connectionString: BLOB_KB_CONFIG.connectionString,
+            },
+          })
+          logger.info(`Initiated Azure multipart upload for ${fileName}: ${result.uploadId}`)
+          return NextResponse.json({
+            uploadId: result.uploadId,
+            key: result.key,
+          })
+        }
+        return NextResponse.json(
+          { error: `Unsupported storage provider: ${storageProvider}` },
+          { status: 400 }
+        )
}
case 'get-part-urls': {
const data: GetPartUrlsRequest = await request.json()
const { uploadId, key, partNumbers } = data
-        const presignedUrls = await Promise.all(
-          partNumbers.map(async (partNumber) => {
-            const command = new UploadPartCommand({
-              Bucket: S3_KB_CONFIG.bucket,
-              Key: key,
-              PartNumber: partNumber,
-              UploadId: uploadId,
-            })
+        if (storageProvider === 's3') {
+          const { getS3MultipartPartUrls } = await import('@/lib/uploads/s3/s3-client')
-            const url = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
-            return { partNumber, url }
+          const presignedUrls = await getS3MultipartPartUrls(key, uploadId, partNumbers)
+          return NextResponse.json({ presignedUrls })
+        }
+        if (storageProvider === 'blob') {
+          const { getMultipartPartUrls } = await import('@/lib/uploads/blob/blob-client')
+          const presignedUrls = await getMultipartPartUrls(key, uploadId, partNumbers, {
+            containerName: BLOB_KB_CONFIG.containerName,
+            accountName: BLOB_KB_CONFIG.accountName,
+            accountKey: BLOB_KB_CONFIG.accountKey,
+            connectionString: BLOB_KB_CONFIG.connectionString,
+          })
-          })
-        )
-        return NextResponse.json({ presignedUrls })
+          return NextResponse.json({ presignedUrls })
+        }
+        return NextResponse.json(
+          { error: `Unsupported storage provider: ${storageProvider}` },
+          { status: 400 }
+        )
}
case 'complete': {
-        const data: CompleteMultipartRequest = await request.json()
+        const data = await request.json()
+        // Handle batch completion
+        if ('uploads' in data) {
+          const results = await Promise.all(
+            data.uploads.map(async (upload: any) => {
+              const { uploadId, key } = upload
+              if (storageProvider === 's3') {
+                const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
+                const parts = upload.parts // S3 format: { ETag, PartNumber }
+                const result = await completeS3MultipartUpload(key, uploadId, parts)
+                return {
+                  success: true,
+                  location: result.location,
+                  path: result.path,
+                  key: result.key,
+                }
+              }
+              if (storageProvider === 'blob') {
+                const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
+                const parts = upload.parts // Azure format: { blockId, partNumber }
+                const result = await completeMultipartUpload(key, uploadId, parts, {
+                  containerName: BLOB_KB_CONFIG.containerName,
+                  accountName: BLOB_KB_CONFIG.accountName,
+                  accountKey: BLOB_KB_CONFIG.accountKey,
+                  connectionString: BLOB_KB_CONFIG.connectionString,
+                })
+                return {
+                  success: true,
+                  location: result.location,
+                  path: result.path,
+                  key: result.key,
+                }
+              }
+              throw new Error(`Unsupported storage provider: ${storageProvider}`)
+            })
+          )
+          logger.info(`Completed ${data.uploads.length} multipart uploads`)
+          return NextResponse.json({ results })
+        }
+        // Handle single completion
         const { uploadId, key, parts } = data
-        const command = new CompleteMultipartUploadCommand({
-          Bucket: S3_KB_CONFIG.bucket,
-          Key: key,
-          UploadId: uploadId,
-          MultipartUpload: {
-            Parts: parts.sort((a, b) => a.PartNumber - b.PartNumber),
-          },
-        })
+        if (storageProvider === 's3') {
+          const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
-        const response = await s3Client.send(command)
+          const result = await completeS3MultipartUpload(key, uploadId, parts)
-        logger.info(`Completed multipart upload for key ${key}`)
+          logger.info(`Completed S3 multipart upload for key ${key}`)
-        const finalPath = `/api/files/serve/s3/${encodeURIComponent(key)}`
+          return NextResponse.json({
+            success: true,
+            location: result.location,
+            path: result.path,
+            key: result.key,
+          })
+        }
+        if (storageProvider === 'blob') {
+          const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
-        return NextResponse.json({
-          success: true,
-          location: response.Location,
-          path: finalPath,
-          key,
-        })
+          const result = await completeMultipartUpload(key, uploadId, parts, {
+            containerName: BLOB_KB_CONFIG.containerName,
+            accountName: BLOB_KB_CONFIG.accountName,
+            accountKey: BLOB_KB_CONFIG.accountKey,
+            connectionString: BLOB_KB_CONFIG.connectionString,
+          })
+          logger.info(`Completed Azure multipart upload for key ${key}`)
+          return NextResponse.json({
+            success: true,
+            location: result.location,
+            path: result.path,
+            key: result.key,
+          })
+        }
+        return NextResponse.json(
+          { error: `Unsupported storage provider: ${storageProvider}` },
+          { status: 400 }
+        )
}
case 'abort': {
const data = await request.json()
const { uploadId, key } = data
-        const command = new AbortMultipartUploadCommand({
-          Bucket: S3_KB_CONFIG.bucket,
-          Key: key,
-          UploadId: uploadId,
-        })
+        if (storageProvider === 's3') {
+          const { abortS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
-        await s3Client.send(command)
+          await abortS3MultipartUpload(key, uploadId)
-        logger.info(`Aborted multipart upload for key ${key}`)
+          logger.info(`Aborted S3 multipart upload for key ${key}`)
+        } else if (storageProvider === 'blob') {
+          const { abortMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
+          await abortMultipartUpload(key, uploadId, {
+            containerName: BLOB_KB_CONFIG.containerName,
+            accountName: BLOB_KB_CONFIG.accountName,
+            accountKey: BLOB_KB_CONFIG.accountKey,
+            connectionString: BLOB_KB_CONFIG.connectionString,
+          })
+          logger.info(`Aborted Azure multipart upload for key ${key}`)
+        } else {
+          return NextResponse.json(
+            { error: `Unsupported storage provider: ${storageProvider}` },
+            { status: 400 }
+          )
+        }
return NextResponse.json({ success: true })
}
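
Taken together, the four actions form the standard multipart flow: initiate, get-part-urls, upload each part, then complete (or abort on failure). A hypothetical client-side sketch against an S3 backend; the route path and the 10 MB part size are assumptions, and the parts format follows the S3 comment above ({ ETag, PartNumber }):

// Hypothetical client flow; the route path and part size are assumptions
async function multipartUpload(file: File) {
  const post = (action: string, body: unknown) =>
    fetch(`/api/files/multipart?action=${action}`, {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify(body),
    }).then((r) => r.json())

  const { uploadId, key } = await post('initiate', {
    fileName: file.name,
    contentType: file.type,
    fileSize: file.size,
  })

  const PART_SIZE = 10 * 1024 * 1024
  const partNumbers = Array.from({ length: Math.ceil(file.size / PART_SIZE) }, (_, i) => i + 1)
  // Assumes each presigned entry is { partNumber, url }, matching the previous S3 code path
  const { presignedUrls } = await post('get-part-urls', { uploadId, key, partNumbers })

  try {
    const parts = await Promise.all(
      presignedUrls.map(async ({ partNumber, url }: { partNumber: number; url: string }) => {
        const chunk = file.slice((partNumber - 1) * PART_SIZE, partNumber * PART_SIZE)
        const res = await fetch(url, { method: 'PUT', body: chunk })
        // S3 exposes each part's ETag as a response header (CORS must allow reading it)
        return { ETag: res.headers.get('ETag') ?? '', PartNumber: partNumber }
      })
    )
    return await post('complete', { uploadId, key, parts })
  } catch (error) {
    await post('abort', { uploadId, key })
    throw error
  }
}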

View File

@@ -76,11 +76,9 @@ export async function POST(request: NextRequest) {
logger.info('File parse request received:', { filePath, fileType })
// Handle multiple files
if (Array.isArray(filePath)) {
const results = []
for (const path of filePath) {
// Skip empty or invalid paths
if (!path || (typeof path === 'string' && path.trim() === '')) {
results.push({
success: false,
@@ -91,12 +89,10 @@ export async function POST(request: NextRequest) {
}
const result = await parseFileSingle(path, fileType)
// Add processing time to metadata
if (result.metadata) {
result.metadata.processingTime = Date.now() - startTime
}
// Transform each result to match expected frontend format
if (result.success) {
results.push({
success: true,
@@ -105,7 +101,7 @@ export async function POST(request: NextRequest) {
name: result.filePath.split('/').pop() || 'unknown',
fileType: result.metadata?.fileType || 'application/octet-stream',
size: result.metadata?.size || 0,
-            binary: false, // We only return text content
+            binary: false,
},
filePath: result.filePath,
})
@@ -120,15 +116,12 @@ export async function POST(request: NextRequest) {
})
}
// Handle single file
const result = await parseFileSingle(filePath, fileType)
// Add processing time to metadata
if (result.metadata) {
result.metadata.processingTime = Date.now() - startTime
}
// Transform single file result to match expected frontend format
if (result.success) {
return NextResponse.json({
success: true,
@@ -142,8 +135,6 @@ export async function POST(request: NextRequest) {
})
}
// Only return 500 for actual server errors, not file processing failures
// File processing failures (like file not found, parsing errors) should return 200 with success:false
return NextResponse.json(result)
} catch (error) {
logger.error('Error in file parse API:', error)
@@ -164,7 +155,6 @@ export async function POST(request: NextRequest) {
async function parseFileSingle(filePath: string, fileType?: string): Promise<ParseResult> {
logger.info('Parsing file:', filePath)
// Validate that filePath is not empty
if (!filePath || filePath.trim() === '') {
return {
success: false,
@@ -173,7 +163,6 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<ParseResult> {
}
}
// Validate path for security before any processing
const pathValidation = validateFilePath(filePath)
if (!pathValidation.isValid) {
return {
@@ -183,49 +172,40 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<ParseResult> {
}
}
// Check if this is an external URL
if (filePath.startsWith('http://') || filePath.startsWith('https://')) {
return handleExternalUrl(filePath, fileType)
}
// Check if this is a cloud storage path (S3 or Blob)
const isS3Path = filePath.includes('/api/files/serve/s3/')
const isBlobPath = filePath.includes('/api/files/serve/blob/')
// Use cloud handler if it's a cloud path or we're in cloud mode
if (isS3Path || isBlobPath || isUsingCloudStorage()) {
return handleCloudFile(filePath, fileType)
}
// Use local handler for local files
return handleLocalFile(filePath, fileType)
}
/**
- * Validate file path for security
+ * Validate file path for security - prevents null byte injection and path traversal attacks
*/
function validateFilePath(filePath: string): { isValid: boolean; error?: string } {
// Check for null bytes
if (filePath.includes('\0')) {
return { isValid: false, error: 'Invalid path: null byte detected' }
}
// Check for path traversal attempts
if (filePath.includes('..')) {
return { isValid: false, error: 'Access denied: path traversal detected' }
}
// Check for tilde characters (home directory access)
if (filePath.includes('~')) {
return { isValid: false, error: 'Invalid path: tilde character not allowed' }
}
// Check for absolute paths outside allowed directories
if (filePath.startsWith('/') && !filePath.startsWith('/api/files/serve/')) {
return { isValid: false, error: 'Path outside allowed directory' }
}
// Check for Windows absolute paths
if (/^[A-Za-z]:\\/.test(filePath)) {
return { isValid: false, error: 'Path outside allowed directory' }
}
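
For illustration, how the checks above classify a few inputs; each result follows directly from one of the branches shown:

// Examples derived from the branches above
validateFilePath('/api/files/serve/s3/report.pdf')  // { isValid: true }
validateFilePath('docs/../../../etc/passwd')        // '..' present: path traversal detected
validateFilePath('~/secrets.txt')                   // tilde character not allowed
validateFilePath('/etc/passwd')                     // absolute path outside /api/files/serve/
validateFilePath('C:\\Windows\\system32\\cmd.exe')  // Windows absolute path rejected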
@@ -260,12 +240,10 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseResult> {
logger.info(`Downloaded file from URL: ${url}, size: ${buffer.length} bytes`)
// Extract filename from URL
const urlPath = new URL(url).pathname
const filename = urlPath.split('/').pop() || 'download'
const extension = path.extname(filename).toLowerCase().substring(1)
// Process the file based on its content type
if (extension === 'pdf') {
return await handlePdfBuffer(buffer, filename, fileType, url)
}
@@ -276,7 +254,6 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
return await handleGenericTextBuffer(buffer, filename, extension, fileType, url)
}
// For binary or unknown files
return handleGenericBuffer(buffer, filename, extension, fileType)
} catch (error) {
logger.error(`Error handling external URL ${url}:`, error)
@@ -289,35 +266,29 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseResult> {
}
/**
- * Handle file stored in cloud storage (S3 or Azure Blob)
+ * Handle file stored in cloud storage
*/
async function handleCloudFile(filePath: string, fileType?: string): Promise<ParseResult> {
try {
// Extract the cloud key from the path
let cloudKey: string
if (filePath.includes('/api/files/serve/s3/')) {
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/s3/')[1])
} else if (filePath.includes('/api/files/serve/blob/')) {
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/blob/')[1])
} else if (filePath.startsWith('/api/files/serve/')) {
// Backwards-compatibility: path like "/api/files/serve/<key>"
cloudKey = decodeURIComponent(filePath.substring('/api/files/serve/'.length))
} else {
// Assume raw key provided
cloudKey = filePath
}
logger.info('Extracted cloud key:', cloudKey)
// Download the file from cloud storage - this can throw for access errors
const fileBuffer = await downloadFile(cloudKey)
logger.info(`Downloaded file from cloud storage: ${cloudKey}, size: ${fileBuffer.length} bytes`)
// Extract the filename from the cloud key
const filename = cloudKey.split('/').pop() || cloudKey
const extension = path.extname(filename).toLowerCase().substring(1)
// Process the file based on its content type
if (extension === 'pdf') {
return await handlePdfBuffer(fileBuffer, filename, fileType, filePath)
}
@@ -325,22 +296,19 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<ParseResult> {
return await handleCsvBuffer(fileBuffer, filename, fileType, filePath)
}
if (isSupportedFileType(extension)) {
// For other supported types that we have parsers for
return await handleGenericTextBuffer(fileBuffer, filename, extension, fileType, filePath)
}
// For binary or unknown files
return handleGenericBuffer(fileBuffer, filename, extension, fileType)
} catch (error) {
logger.error(`Error handling cloud file ${filePath}:`, error)
-    // Check if this is a download/access error that should trigger a 500 response
+    // For download/access errors, throw to trigger 500 response
     const errorMessage = (error as Error).message
     if (errorMessage.includes('Access denied') || errorMessage.includes('Forbidden')) {
-      // For access errors, throw to trigger 500 response
       throw new Error(`Error accessing file from cloud storage: ${errorMessage}`)
     }
-    // For other errors (parsing, processing), return success:false
+    // For other errors (parsing, processing), return success:false and an error message
return {
success: false,
error: `Error accessing file from cloud storage: ${errorMessage}`,
@@ -354,28 +322,23 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<ParseResult> {
*/
async function handleLocalFile(filePath: string, fileType?: string): Promise<ParseResult> {
try {
// Extract filename from path
const filename = filePath.split('/').pop() || filePath
const fullPath = path.join(UPLOAD_DIR_SERVER, filename)
logger.info('Processing local file:', fullPath)
// Check if file exists
try {
await fsPromises.access(fullPath)
} catch {
throw new Error(`File not found: ${filename}`)
}
// Parse the file directly
const result = await parseFile(fullPath)
// Get file stats for metadata
const stats = await fsPromises.stat(fullPath)
const fileBuffer = await readFile(fullPath)
const hash = createHash('md5').update(fileBuffer).digest('hex')
// Extract file extension for type detection
const extension = path.extname(filename).toLowerCase().substring(1)
return {
@@ -386,7 +349,7 @@ async function handleLocalFile(filePath: string, fileType?: string): Promise<ParseResult> {
fileType: fileType || getMimeType(extension),
size: stats.size,
hash,
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
} catch (error) {
@@ -425,15 +388,14 @@ async function handlePdfBuffer(
fileType: fileType || 'application/pdf',
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
} catch (error) {
logger.error('Failed to parse PDF in memory:', error)
// Create fallback message for PDF parsing failure
const content = createPdfFailureMessage(
-      0, // We can't determine page count without parsing
+      0,
fileBuffer.length,
originalPath || filename,
(error as Error).message
@@ -447,7 +409,7 @@ async function handlePdfBuffer(
fileType: fileType || 'application/pdf',
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
}
@@ -465,7 +427,6 @@ async function handleCsvBuffer(
try {
logger.info(`Parsing CSV in memory: ${filename}`)
// Use the parseBuffer function from our library
const { parseBuffer } = await import('@/lib/file-parsers')
const result = await parseBuffer(fileBuffer, 'csv')
@@ -477,7 +438,7 @@ async function handleCsvBuffer(
fileType: fileType || 'text/csv',
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
} catch (error) {
@@ -490,7 +451,7 @@ async function handleCsvBuffer(
fileType: 'text/csv',
size: 0,
hash: '',
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
}
@@ -509,7 +470,6 @@ async function handleGenericTextBuffer(
try {
logger.info(`Parsing text file in memory: ${filename}`)
// Try to use a specialized parser if available
try {
const { parseBuffer, isSupportedFileType } = await import('@/lib/file-parsers')
@@ -524,7 +484,7 @@ async function handleGenericTextBuffer(
fileType: fileType || getMimeType(extension),
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
}
@@ -532,7 +492,6 @@ async function handleGenericTextBuffer(
logger.warn('Specialized parser failed, falling back to generic parsing:', parserError)
}
// Fallback to generic text parsing
const content = fileBuffer.toString('utf-8')
return {
@@ -543,7 +502,7 @@ async function handleGenericTextBuffer(
fileType: fileType || getMimeType(extension),
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
} catch (error) {
@@ -556,7 +515,7 @@ async function handleGenericTextBuffer(
fileType: 'text/plain',
size: 0,
hash: '',
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
}
@@ -584,7 +543,7 @@ function handleGenericBuffer(
fileType: fileType || getMimeType(extension),
size: fileBuffer.length,
hash: createHash('md5').update(fileBuffer).digest('hex'),
-        processingTime: 0, // Will be set by caller
+        processingTime: 0,
},
}
}
@@ -594,25 +553,11 @@ function handleGenericBuffer(
*/
async function parseBufferAsPdf(buffer: Buffer) {
try {
-    // Import parsers dynamically to avoid initialization issues in tests
-    // First try to use the main PDF parser
-    try {
-      const { PdfParser } = await import('@/lib/file-parsers/pdf-parser')
-      const parser = new PdfParser()
-      logger.info('Using main PDF parser for buffer')
+    const { PdfParser } = await import('@/lib/file-parsers/pdf-parser')
+    const parser = new PdfParser()
+    logger.info('Using main PDF parser for buffer')
-      if (parser.parseBuffer) {
-        return await parser.parseBuffer(buffer)
-      }
-      throw new Error('PDF parser does not support buffer parsing')
-    } catch (error) {
-      // Fallback to raw PDF parser
-      logger.warn('Main PDF parser failed, using raw parser for buffer:', error)
-      const { RawPdfParser } = await import('@/lib/file-parsers/raw-pdf-parser')
-      const rawParser = new RawPdfParser()
-      return await rawParser.parseBuffer(buffer)
-    }
+    return await parser.parseBuffer(buffer)
} catch (error) {
throw new Error(`PDF parsing failed: ${(error as Error).message}`)
}
@@ -655,7 +600,7 @@ Please use a PDF viewer for best results.`
}
/**
- * Create error message for PDF parsing failure
+ * Create error message for PDF parsing failure and make it more readable
*/
function createPdfFailureMessage(
pageCount: number,

View File

@@ -0,0 +1,361 @@
import { PutObjectCommand } from '@aws-sdk/client-s3'
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
import { type NextRequest, NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import {
BLOB_CHAT_CONFIG,
BLOB_CONFIG,
BLOB_COPILOT_CONFIG,
BLOB_KB_CONFIG,
S3_CHAT_CONFIG,
S3_CONFIG,
S3_COPILOT_CONFIG,
S3_KB_CONFIG,
} from '@/lib/uploads/setup'
import { validateFileType } from '@/lib/uploads/validation'
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
const logger = createLogger('BatchPresignedUploadAPI')
interface BatchFileRequest {
fileName: string
contentType: string
fileSize: number
}
interface BatchPresignedUrlRequest {
files: BatchFileRequest[]
}
type UploadType = 'general' | 'knowledge-base' | 'chat' | 'copilot'
export async function POST(request: NextRequest) {
try {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
let data: BatchPresignedUrlRequest
try {
data = await request.json()
} catch {
return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 })
}
const { files } = data
if (!files || !Array.isArray(files) || files.length === 0) {
return NextResponse.json(
{ error: 'files array is required and cannot be empty' },
{ status: 400 }
)
}
if (files.length > 100) {
return NextResponse.json(
{ error: 'Cannot process more than 100 files at once' },
{ status: 400 }
)
}
const uploadTypeParam = request.nextUrl.searchParams.get('type')
const uploadType: UploadType =
uploadTypeParam === 'knowledge-base'
? 'knowledge-base'
: uploadTypeParam === 'chat'
? 'chat'
: uploadTypeParam === 'copilot'
? 'copilot'
: 'general'
const MAX_FILE_SIZE = 100 * 1024 * 1024
for (const file of files) {
if (!file.fileName?.trim()) {
return NextResponse.json({ error: 'fileName is required for all files' }, { status: 400 })
}
if (!file.contentType?.trim()) {
return NextResponse.json(
{ error: 'contentType is required for all files' },
{ status: 400 }
)
}
if (!file.fileSize || file.fileSize <= 0) {
return NextResponse.json(
{ error: 'fileSize must be positive for all files' },
{ status: 400 }
)
}
if (file.fileSize > MAX_FILE_SIZE) {
return NextResponse.json(
{ error: `File ${file.fileName} exceeds maximum size of ${MAX_FILE_SIZE} bytes` },
{ status: 400 }
)
}
if (uploadType === 'knowledge-base') {
const fileValidationError = validateFileType(file.fileName, file.contentType)
if (fileValidationError) {
return NextResponse.json(
{
error: fileValidationError.message,
code: fileValidationError.code,
supportedTypes: fileValidationError.supportedTypes,
},
{ status: 400 }
)
}
}
}
const sessionUserId = session.user.id
if (uploadType === 'copilot' && !sessionUserId?.trim()) {
return NextResponse.json(
{ error: 'Authenticated user session is required for copilot uploads' },
{ status: 400 }
)
}
if (!isUsingCloudStorage()) {
return NextResponse.json(
{ error: 'Direct uploads are only available when cloud storage is enabled' },
{ status: 400 }
)
}
const storageProvider = getStorageProvider()
logger.info(
`Generating batch ${uploadType} presigned URLs for ${files.length} files using ${storageProvider}`
)
const startTime = Date.now()
let result
switch (storageProvider) {
case 's3':
result = await handleBatchS3PresignedUrls(files, uploadType, sessionUserId)
break
case 'blob':
result = await handleBatchBlobPresignedUrls(files, uploadType, sessionUserId)
break
default:
return NextResponse.json(
{ error: `Unknown storage provider: ${storageProvider}` },
{ status: 500 }
)
}
const duration = Date.now() - startTime
logger.info(
`Generated ${files.length} presigned URLs in ${duration}ms (avg ${Math.round(duration / files.length)}ms per file)`
)
return NextResponse.json(result)
} catch (error) {
logger.error('Error generating batch presigned URLs:', error)
return createErrorResponse(
error instanceof Error ? error : new Error('Failed to generate batch presigned URLs')
)
}
}
async function handleBatchS3PresignedUrls(
files: BatchFileRequest[],
uploadType: UploadType,
userId?: string
) {
const config =
uploadType === 'knowledge-base'
? S3_KB_CONFIG
: uploadType === 'chat'
? S3_CHAT_CONFIG
: uploadType === 'copilot'
? S3_COPILOT_CONFIG
: S3_CONFIG
if (!config.bucket || !config.region) {
throw new Error(`S3 configuration missing for ${uploadType} uploads`)
}
const { getS3Client, sanitizeFilenameForMetadata } = await import('@/lib/uploads/s3/s3-client')
const s3Client = getS3Client()
let prefix = ''
if (uploadType === 'knowledge-base') {
prefix = 'kb/'
} else if (uploadType === 'chat') {
prefix = 'chat/'
} else if (uploadType === 'copilot') {
prefix = `${userId}/`
}
const baseMetadata: Record<string, string> = {
uploadedAt: new Date().toISOString(),
}
if (uploadType === 'knowledge-base') {
baseMetadata.purpose = 'knowledge-base'
} else if (uploadType === 'chat') {
baseMetadata.purpose = 'chat'
} else if (uploadType === 'copilot') {
baseMetadata.purpose = 'copilot'
baseMetadata.userId = userId || ''
}
const results = await Promise.all(
files.map(async (file) => {
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
const sanitizedOriginalName = sanitizeFilenameForMetadata(file.fileName)
const metadata = {
...baseMetadata,
originalName: sanitizedOriginalName,
}
const command = new PutObjectCommand({
Bucket: config.bucket,
Key: uniqueKey,
ContentType: file.contentType,
Metadata: metadata,
})
const presignedUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
const finalPath =
uploadType === 'chat'
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
return {
fileName: file.fileName,
presignedUrl,
fileInfo: {
path: finalPath,
key: uniqueKey,
name: file.fileName,
size: file.fileSize,
type: file.contentType,
},
}
})
)
return {
files: results,
directUploadSupported: true,
}
}
async function handleBatchBlobPresignedUrls(
files: BatchFileRequest[],
uploadType: UploadType,
userId?: string
) {
const config =
uploadType === 'knowledge-base'
? BLOB_KB_CONFIG
: uploadType === 'chat'
? BLOB_CHAT_CONFIG
: uploadType === 'copilot'
? BLOB_COPILOT_CONFIG
: BLOB_CONFIG
if (
!config.accountName ||
!config.containerName ||
(!config.accountKey && !config.connectionString)
) {
throw new Error(`Azure Blob configuration missing for ${uploadType} uploads`)
}
const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
const { BlobSASPermissions, generateBlobSASQueryParameters, StorageSharedKeyCredential } =
await import('@azure/storage-blob')
const blobServiceClient = getBlobServiceClient()
const containerClient = blobServiceClient.getContainerClient(config.containerName)
let prefix = ''
if (uploadType === 'knowledge-base') {
prefix = 'kb/'
} else if (uploadType === 'chat') {
prefix = 'chat/'
} else if (uploadType === 'copilot') {
prefix = `${userId}/`
}
const baseUploadHeaders: Record<string, string> = {
'x-ms-blob-type': 'BlockBlob',
'x-ms-meta-uploadedat': new Date().toISOString(),
}
if (uploadType === 'knowledge-base') {
baseUploadHeaders['x-ms-meta-purpose'] = 'knowledge-base'
} else if (uploadType === 'chat') {
baseUploadHeaders['x-ms-meta-purpose'] = 'chat'
} else if (uploadType === 'copilot') {
baseUploadHeaders['x-ms-meta-purpose'] = 'copilot'
baseUploadHeaders['x-ms-meta-userid'] = encodeURIComponent(userId || '')
}
const results = await Promise.all(
files.map(async (file) => {
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
const sasOptions = {
containerName: config.containerName,
blobName: uniqueKey,
permissions: BlobSASPermissions.parse('w'),
startsOn: new Date(),
expiresOn: new Date(Date.now() + 3600 * 1000),
}
const sasToken = generateBlobSASQueryParameters(
sasOptions,
new StorageSharedKeyCredential(config.accountName, config.accountKey || '')
).toString()
const presignedUrl = `${blockBlobClient.url}?${sasToken}`
const finalPath =
uploadType === 'chat'
? blockBlobClient.url
: `/api/files/serve/blob/${encodeURIComponent(uniqueKey)}`
const uploadHeaders = {
...baseUploadHeaders,
'x-ms-blob-content-type': file.contentType,
'x-ms-meta-originalname': encodeURIComponent(file.fileName),
}
return {
fileName: file.fileName,
presignedUrl,
fileInfo: {
path: finalPath,
key: uniqueKey,
name: file.fileName,
size: file.fileSize,
type: file.contentType,
},
uploadHeaders,
}
})
)
return {
files: results,
directUploadSupported: true,
}
}
export async function OPTIONS() {
return createOptionsResponse()
}
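
A hypothetical client-side usage of this batch endpoint; the route path is an assumption based on the handler, while the request and response shapes follow the code above (each result carries a presignedUrl, the final fileInfo, and, for Azure, uploadHeaders):

// Request presigned URLs for many files at once, then PUT each file directly
async function uploadBatch(files: File[]) {
  const res = await fetch('/api/files/presigned/batch?type=knowledge-base', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      files: files.map((f) => ({ fileName: f.name, contentType: f.type, fileSize: f.size })),
    }),
  })
  if (!res.ok) throw new Error((await res.json()).error)
  const { files: presigned } = await res.json()
  // The response preserves request order, so entries can be matched by index
  await Promise.all(
    presigned.map((entry: any, i: number) =>
      fetch(entry.presignedUrl, {
        method: 'PUT',
        headers: entry.uploadHeaders ?? { 'Content-Type': files[i].type },
        body: files[i],
      })
    )
  )
  return presigned.map((entry: any) => entry.fileInfo)
}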

View File

@@ -1,7 +1,13 @@
import { NextRequest } from 'next/server'
-import { beforeEach, describe, expect, test, vi } from 'vitest'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import { setupFileApiMocks } from '@/app/api/__test-utils__/utils'
/**
* Tests for file presigned API route
*
* @vitest-environment node
*/
describe('/api/files/presigned', () => {
beforeEach(() => {
vi.clearAllMocks()
@@ -19,7 +25,7 @@ describe('/api/files/presigned', () => {
})
describe('POST', () => {
-    test('should return error when cloud storage is not enabled', async () => {
+    it('should return error when cloud storage is not enabled', async () => {
setupFileApiMocks({
cloudEnabled: false,
storageProvider: 's3',
@@ -39,7 +45,7 @@ describe('/api/files/presigned', () => {
const response = await POST(request)
const data = await response.json()
-      expect(response.status).toBe(500) // Changed from 400 to 500 (StorageConfigError)
+      expect(response.status).toBe(500)
expect(data.error).toBe('Direct uploads are only available when cloud storage is enabled')
expect(data.code).toBe('STORAGE_CONFIG_ERROR')
expect(data.directUploadSupported).toBe(false)

View File

@@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
import { isImageFileType } from '@/lib/uploads/file-utils'
// Dynamic imports for storage clients to avoid client-side bundling
import {
BLOB_CHAT_CONFIG,
@@ -16,6 +17,7 @@ import {
S3_COPILOT_CONFIG,
S3_KB_CONFIG,
} from '@/lib/uploads/setup'
import { validateFileType } from '@/lib/uploads/validation'
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
const logger = createLogger('PresignedUploadAPI')
@@ -96,6 +98,13 @@ export async function POST(request: NextRequest) {
? 'copilot'
: 'general'
if (uploadType === 'knowledge-base') {
const fileValidationError = validateFileType(fileName, contentType)
if (fileValidationError) {
throw new ValidationError(`${fileValidationError.message}`)
}
}
// Evaluate user id from session for copilot uploads
const sessionUserId = session.user.id
@@ -104,6 +113,12 @@ export async function POST(request: NextRequest) {
if (!sessionUserId?.trim()) {
throw new ValidationError('Authenticated user session is required for copilot uploads')
}
// Only allow image uploads for copilot
if (!isImageFileType(contentType)) {
throw new ValidationError(
'Only image files (JPEG, PNG, GIF, WebP, SVG) are allowed for copilot uploads'
)
}
}
if (!isUsingCloudStorage()) {
@@ -224,10 +239,9 @@ async function handleS3PresignedUrl(
)
}
-  // For chat images, use direct S3 URLs since they need to be permanently accessible
-  // For other files, use serve path for access control
+  // For chat images and knowledge base files, use direct URLs since they need to be accessible by external services
   const finalPath =
-    uploadType === 'chat'
+    uploadType === 'chat' || uploadType === 'knowledge-base'
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`

View File

@@ -2,7 +2,7 @@ import { readFile } from 'fs/promises'
import type { NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
import { downloadFile, getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
-import { BLOB_KB_CONFIG, S3_KB_CONFIG } from '@/lib/uploads/setup'
+import { S3_KB_CONFIG } from '@/lib/uploads/setup'
import '@/lib/uploads/setup.server'
import {
@@ -15,19 +15,6 @@ import {
const logger = createLogger('FilesServeAPI')
-async function streamToBuffer(readableStream: NodeJS.ReadableStream): Promise<Buffer> {
-  return new Promise((resolve, reject) => {
-    const chunks: Buffer[] = []
-    readableStream.on('data', (data) => {
-      chunks.push(data instanceof Buffer ? data : Buffer.from(data))
-    })
-    readableStream.on('end', () => {
-      resolve(Buffer.concat(chunks))
-    })
-    readableStream.on('error', reject)
-  })
-}
/**
* Main API route handler for serving files
*/
@@ -102,49 +89,23 @@ async function handleLocalFile(filename: string): Promise<NextResponse> {
}
async function downloadKBFile(cloudKey: string): Promise<Buffer> {
-  logger.info(`Downloading KB file: ${cloudKey}`)
   const storageProvider = getStorageProvider()
   if (storageProvider === 'blob') {
     logger.info(`Downloading KB file from Azure Blob Storage: ${cloudKey}`)
-    // Use KB-specific blob configuration
-    const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
-    const blobServiceClient = getBlobServiceClient()
-    const containerClient = blobServiceClient.getContainerClient(BLOB_KB_CONFIG.containerName)
-    const blockBlobClient = containerClient.getBlockBlobClient(cloudKey)
-    const downloadBlockBlobResponse = await blockBlobClient.download()
-    if (!downloadBlockBlobResponse.readableStreamBody) {
-      throw new Error('Failed to get readable stream from blob download')
-    }
-    // Convert stream to buffer
-    return await streamToBuffer(downloadBlockBlobResponse.readableStreamBody)
+    const { BLOB_KB_CONFIG } = await import('@/lib/uploads/setup')
+    return downloadFile(cloudKey, {
+      containerName: BLOB_KB_CONFIG.containerName,
+      accountName: BLOB_KB_CONFIG.accountName,
+      accountKey: BLOB_KB_CONFIG.accountKey,
+      connectionString: BLOB_KB_CONFIG.connectionString,
+    })
   }
   if (storageProvider === 's3') {
     logger.info(`Downloading KB file from S3: ${cloudKey}`)
-    // Use KB-specific S3 configuration
-    const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
-    const { GetObjectCommand } = await import('@aws-sdk/client-s3')
-    const s3Client = getS3Client()
-    const command = new GetObjectCommand({
-      Bucket: S3_KB_CONFIG.bucket,
-      Key: cloudKey,
-    })
-    const response = await s3Client.send(command)
-    if (!response.Body) {
-      throw new Error('No body in S3 response')
-    }
-    // Convert stream to buffer using the same method as the regular S3 client
-    const stream = response.Body as any
-    return new Promise<Buffer>((resolve, reject) => {
-      const chunks: Buffer[] = []
-      stream.on('data', (chunk: Buffer) => chunks.push(chunk))
-      stream.on('end', () => resolve(Buffer.concat(chunks)))
-      stream.on('error', reject)
+    return downloadFile(cloudKey, {
+      bucket: S3_KB_CONFIG.bucket,
+      region: S3_KB_CONFIG.region,
+    })
}
@@ -167,17 +128,22 @@ async function handleCloudProxy(
if (isKBFile) {
fileBuffer = await downloadKBFile(cloudKey)
} else if (bucketType === 'copilot') {
-      // Download from copilot-specific bucket
       const storageProvider = getStorageProvider()
       if (storageProvider === 's3') {
-        const { downloadFromS3WithConfig } = await import('@/lib/uploads/s3/s3-client')
         const { S3_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
-        fileBuffer = await downloadFromS3WithConfig(cloudKey, S3_COPILOT_CONFIG)
+        fileBuffer = await downloadFile(cloudKey, {
+          bucket: S3_COPILOT_CONFIG.bucket,
+          region: S3_COPILOT_CONFIG.region,
+        })
       } else if (storageProvider === 'blob') {
-        // For Azure Blob, use the default downloadFile for now
-        // TODO: Add downloadFromBlobWithConfig when needed
-        fileBuffer = await downloadFile(cloudKey)
+        const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
+        fileBuffer = await downloadFile(cloudKey, {
+          containerName: BLOB_COPILOT_CONFIG.containerName,
+          accountName: BLOB_COPILOT_CONFIG.accountName,
+          accountKey: BLOB_COPILOT_CONFIG.accountKey,
+          connectionString: BLOB_COPILOT_CONFIG.connectionString,
+        })
} else {
fileBuffer = await downloadFile(cloudKey)
}
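
After this change, every KB and copilot download flows through a single downloadFile helper that selects the bucket or container from an optional config argument. The helper's exact signature is not part of this diff; inferred from the call sites above, it is roughly:

// Inferred shape only; the real signature lives in @/lib/uploads and is not shown in this diff
type S3DownloadConfig = { bucket: string; region: string }
type BlobDownloadConfig = {
  containerName: string
  accountName: string
  accountKey?: string
  connectionString?: string
}
declare function downloadFile(key: string, config?: S3DownloadConfig | BlobDownloadConfig): Promise<Buffer>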

View File

@@ -186,3 +186,190 @@ describe('File Upload API Route', () => {
expect(response.headers.get('Access-Control-Allow-Headers')).toBe('Content-Type')
})
})
describe('File Upload Security Tests', () => {
beforeEach(() => {
vi.resetModules()
vi.clearAllMocks()
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue({
user: { id: 'test-user-id' },
}),
}))
vi.doMock('@/lib/uploads', () => ({
isUsingCloudStorage: vi.fn().mockReturnValue(false),
uploadFile: vi.fn().mockResolvedValue({
key: 'test-key',
path: '/test/path',
}),
}))
vi.doMock('@/lib/uploads/setup.server', () => ({}))
})
afterEach(() => {
vi.clearAllMocks()
})
describe('File Extension Validation', () => {
it('should accept allowed file types', async () => {
const allowedTypes = [
'pdf',
'doc',
'docx',
'txt',
'md',
'png',
'jpg',
'jpeg',
'gif',
'csv',
'xlsx',
'xls',
]
for (const ext of allowedTypes) {
const formData = new FormData()
const file = new File(['test content'], `test.${ext}`, { type: 'application/octet-stream' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(200)
}
})
it('should reject HTML files to prevent XSS', async () => {
const formData = new FormData()
const maliciousContent = '<script>alert("XSS")</script>'
const file = new File([maliciousContent], 'malicious.html', { type: 'text/html' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'html' is not allowed")
})
it('should reject SVG files to prevent XSS', async () => {
const formData = new FormData()
const maliciousSvg = '<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
const file = new File([maliciousSvg], 'malicious.svg', { type: 'image/svg+xml' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'svg' is not allowed")
})
it('should reject JavaScript files', async () => {
const formData = new FormData()
const maliciousJs = 'alert("XSS")'
const file = new File([maliciousJs], 'malicious.js', { type: 'application/javascript' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'js' is not allowed")
})
it('should reject files without extensions', async () => {
const formData = new FormData()
const file = new File(['test content'], 'noextension', { type: 'application/octet-stream' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'noextension' is not allowed")
})
it('should handle multiple files with mixed valid/invalid types', async () => {
const formData = new FormData()
// Valid file
const validFile = new File(['valid content'], 'valid.pdf', { type: 'application/pdf' })
formData.append('file', validFile)
// Invalid file (should cause rejection of entire request)
const invalidFile = new File(['<script>alert("XSS")</script>'], 'malicious.html', {
type: 'text/html',
})
formData.append('file', invalidFile)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(400)
const data = await response.json()
expect(data.message).toContain("File type 'html' is not allowed")
})
})
describe('Authentication Requirements', () => {
it('should reject uploads without authentication', async () => {
vi.doMock('@/lib/auth', () => ({
getSession: vi.fn().mockResolvedValue(null),
}))
const formData = new FormData()
const file = new File(['test content'], 'test.pdf', { type: 'application/pdf' })
formData.append('file', file)
const req = new Request('http://localhost/api/files/upload', {
method: 'POST',
body: formData,
})
const { POST } = await import('@/app/api/files/upload/route')
const response = await POST(req as any)
expect(response.status).toBe(401)
const data = await response.json()
expect(data.error).toBe('Unauthorized')
})
})
})

View File

@@ -9,6 +9,34 @@ import {
InvalidRequestError,
} from '@/app/api/files/utils'
// Allowlist of permitted file extensions for security
const ALLOWED_EXTENSIONS = new Set([
// Documents
'pdf',
'doc',
'docx',
'txt',
'md',
// Images (safe formats)
'png',
'jpg',
'jpeg',
'gif',
// Data files
'csv',
'xlsx',
'xls',
])
/**
* Validates file extension against allowlist
*/
function validateFileExtension(filename: string): boolean {
const extension = filename.split('.').pop()?.toLowerCase()
if (!extension) return false
return ALLOWED_EXTENSIONS.has(extension)
}
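
A few concrete outcomes of this check; note that a dot-less name yields the whole name as its "extension", which is why the tests above expect "File type 'noextension' is not allowed":

// Examples of the allowlist check above
validateFileExtension('report.pdf')     // true
validateFileExtension('photo.JPG')      // true (the extension is lowercased first)
validateFileExtension('malicious.html') // false ('html' is not in ALLOWED_EXTENSIONS)
validateFileExtension('noextension')    // false (split('.').pop() returns the whole name)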
export const dynamic = 'force-dynamic'
const logger = createLogger('FilesUploadAPI')
@@ -49,6 +77,14 @@ export async function POST(request: NextRequest) {
// Process each file
for (const file of files) {
const originalName = file.name
if (!validateFileExtension(originalName)) {
const extension = originalName.split('.').pop()?.toLowerCase() || 'unknown'
throw new InvalidRequestError(
`File type '${extension}' is not allowed. Allowed types: ${Array.from(ALLOWED_EXTENSIONS).join(', ')}`
)
}
const bytes = await file.arrayBuffer()
const buffer = Buffer.from(bytes)

View File

@@ -0,0 +1,327 @@
import { describe, expect, it } from 'vitest'
import { createFileResponse, extractFilename } from './utils'
describe('extractFilename', () => {
describe('legitimate file paths', () => {
it('should extract filename from standard serve path', () => {
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
})
it('should extract filename from serve path with special characters', () => {
expect(extractFilename('/api/files/serve/document-with-dashes_and_underscores.pdf')).toBe(
'document-with-dashes_and_underscores.pdf'
)
})
it('should handle simple filename without serve path', () => {
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
})
it('should extract last segment from nested path', () => {
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
})
})
describe('cloud storage paths', () => {
it('should preserve S3 path structure', () => {
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
's3/1234567890-test-file.txt'
)
})
it('should preserve S3 path with nested folders', () => {
expect(extractFilename('/api/files/serve/s3/folder/subfolder/document.pdf')).toBe(
's3/folder/subfolder/document.pdf'
)
})
it('should preserve Azure Blob path structure', () => {
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
'blob/1234567890-test-document.pdf'
)
})
it('should preserve Blob path with nested folders', () => {
expect(extractFilename('/api/files/serve/blob/uploads/user-files/report.xlsx')).toBe(
'blob/uploads/user-files/report.xlsx'
)
})
})
describe('security - path traversal prevention', () => {
it('should sanitize basic path traversal attempt', () => {
expect(extractFilename('/api/files/serve/../config.txt')).toBe('config.txt')
})
it('should sanitize deep path traversal attempt', () => {
expect(extractFilename('/api/files/serve/../../../../../etc/passwd')).toBe('etcpasswd')
})
it('should sanitize multiple path traversal patterns', () => {
expect(extractFilename('/api/files/serve/../../secret.txt')).toBe('secret.txt')
})
it('should sanitize path traversal with forward slashes', () => {
expect(extractFilename('/api/files/serve/../../../system/file')).toBe('systemfile')
})
it('should sanitize mixed path traversal patterns', () => {
expect(extractFilename('/api/files/serve/../folder/../file.txt')).toBe('folderfile.txt')
})
it('should remove directory separators from local filenames', () => {
expect(extractFilename('/api/files/serve/folder/with/separators.txt')).toBe(
'folderwithseparators.txt'
)
})
it('should handle backslash path separators (Windows style)', () => {
expect(extractFilename('/api/files/serve/folder\\file.txt')).toBe('folderfile.txt')
})
})
describe('cloud storage path traversal prevention', () => {
it('should sanitize S3 path traversal attempts while preserving structure', () => {
expect(extractFilename('/api/files/serve/s3/../config')).toBe('s3/config')
})
it('should sanitize S3 path with nested traversal attempts', () => {
expect(extractFilename('/api/files/serve/s3/folder/../sensitive/../file.txt')).toBe(
's3/folder/sensitive/file.txt'
)
})
it('should sanitize Blob path traversal attempts while preserving structure', () => {
expect(extractFilename('/api/files/serve/blob/../system.txt')).toBe('blob/system.txt')
})
it('should remove leading dots from cloud path segments', () => {
expect(extractFilename('/api/files/serve/s3/.hidden/../file.txt')).toBe('s3/hidden/file.txt')
})
})
describe('edge cases and error handling', () => {
it('should handle filename with dots (but not traversal)', () => {
expect(extractFilename('/api/files/serve/file.with.dots.txt')).toBe('file.with.dots.txt')
})
it('should handle filename with multiple extensions', () => {
expect(extractFilename('/api/files/serve/archive.tar.gz')).toBe('archive.tar.gz')
})
it('should throw error for empty filename after sanitization', () => {
expect(() => extractFilename('/api/files/serve/')).toThrow(
'Invalid or empty filename after sanitization'
)
})
it('should throw error for filename that becomes empty after path traversal removal', () => {
expect(() => extractFilename('/api/files/serve/../..')).toThrow(
'Invalid or empty filename after sanitization'
)
})
it('should handle single character filenames', () => {
expect(extractFilename('/api/files/serve/a')).toBe('a')
})
it('should handle numeric filenames', () => {
expect(extractFilename('/api/files/serve/123')).toBe('123')
})
})
describe('backward compatibility', () => {
it('should match old behavior for legitimate local files', () => {
// These test cases verify that our security fix maintains exact backward compatibility
// for all legitimate use cases found in the existing codebase
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
expect(extractFilename('/api/files/serve/nonexistent.txt')).toBe('nonexistent.txt')
})
it('should match old behavior for legitimate cloud files', () => {
// These test cases are from the actual delete route tests
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
's3/1234567890-test-file.txt'
)
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
'blob/1234567890-test-document.pdf'
)
})
it('should match old behavior for simple paths', () => {
// These match the mock implementations in serve route tests
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
})
})
describe('File Serving Security Tests', () => {
describe('createFileResponse security headers', () => {
it('should serve safe images inline with proper headers', () => {
const response = createFileResponse({
buffer: Buffer.from('fake-image-data'),
contentType: 'image/png',
filename: 'safe-image.png',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('image/png')
expect(response.headers.get('Content-Disposition')).toBe(
'inline; filename="safe-image.png"'
)
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
expect(response.headers.get('Content-Security-Policy')).toBe(
"default-src 'none'; style-src 'unsafe-inline'; sandbox;"
)
})
it('should serve PDFs inline safely', () => {
const response = createFileResponse({
buffer: Buffer.from('fake-pdf-data'),
contentType: 'application/pdf',
filename: 'document.pdf',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/pdf')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.pdf"')
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
it('should force attachment for HTML files to prevent XSS', () => {
const response = createFileResponse({
buffer: Buffer.from('<script>alert("XSS")</script>'),
contentType: 'text/html',
filename: 'malicious.html',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.html"'
)
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
it('should force attachment for SVG files to prevent XSS', () => {
const response = createFileResponse({
buffer: Buffer.from(
'<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
),
contentType: 'image/svg+xml',
filename: 'malicious.svg',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.svg"'
)
})
it('should override dangerous content types to safe alternatives', () => {
const response = createFileResponse({
buffer: Buffer.from('<svg>safe content</svg>'),
contentType: 'image/svg+xml',
filename: 'image.png', // Extension doesn't match content-type
})
expect(response.status).toBe(200)
// Should override SVG content type to plain text for safety
expect(response.headers.get('Content-Type')).toBe('text/plain')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="image.png"')
})
it('should force attachment for JavaScript files', () => {
const response = createFileResponse({
buffer: Buffer.from('alert("XSS")'),
contentType: 'application/javascript',
filename: 'malicious.js',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.js"'
)
})
it('should force attachment for CSS files', () => {
const response = createFileResponse({
buffer: Buffer.from('body { background: url(javascript:alert("XSS")) }'),
contentType: 'text/css',
filename: 'malicious.css',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.css"'
)
})
it('should force attachment for XML files', () => {
const response = createFileResponse({
buffer: Buffer.from('<?xml version="1.0"?><root><script>alert("XSS")</script></root>'),
contentType: 'application/xml',
filename: 'malicious.xml',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="malicious.xml"'
)
})
it('should serve text files safely', () => {
const response = createFileResponse({
buffer: Buffer.from('Safe text content'),
contentType: 'text/plain',
filename: 'document.txt',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('text/plain')
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.txt"')
})
it('should force attachment for unknown/unsafe content types', () => {
const response = createFileResponse({
buffer: Buffer.from('unknown content'),
contentType: 'application/unknown',
filename: 'unknown.bin',
})
expect(response.status).toBe(200)
expect(response.headers.get('Content-Type')).toBe('application/unknown')
expect(response.headers.get('Content-Disposition')).toBe(
'attachment; filename="unknown.bin"'
)
})
})
describe('Content Security Policy', () => {
it('should include CSP header in all responses', () => {
const response = createFileResponse({
buffer: Buffer.from('test'),
contentType: 'text/plain',
filename: 'test.txt',
})
const csp = response.headers.get('Content-Security-Policy')
expect(csp).toBe("default-src 'none'; style-src 'unsafe-inline'; sandbox;")
})
it('should include X-Content-Type-Options header', () => {
const response = createFileResponse({
buffer: Buffer.from('test'),
contentType: 'text/plain',
filename: 'test.txt',
})
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
})
})
})
})
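Taken together, these tests pin down both layers of the fix: extractFilename sanitizes the path, and createFileResponse hardens the response headers. A hypothetical serve handler gluing the helpers together might look like this (the actual serve route is not part of this diff, so everything outside the utils module is an assumption):
import { readFile } from 'fs/promises'
export async function GET(request: NextRequest) {
  const filename = extractFilename(new URL(request.url).pathname) // throws if sanitization empties the name
  const localPath = findLocalFile(filename)
  if (!localPath) {
    return NextResponse.json({ error: 'File not found' }, { status: 404 })
  }
  const extension = filename.split('.').pop()?.toLowerCase() ?? ''
  const contentType = contentTypeMap[extension] ?? 'application/octet-stream'
  return createFileResponse({ buffer: await readFile(localPath), contentType, filename })
}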

View File

@@ -70,7 +70,6 @@ export const contentTypeMap: Record<string, string> = {
jpg: 'image/jpeg',
jpeg: 'image/jpeg',
gif: 'image/gif',
svg: 'image/svg+xml',
// Archive formats
zip: 'application/zip',
// Folder format
@@ -153,10 +152,43 @@ export function extractBlobKey(path: string): string {
* Extract filename from a serve path
*/
export function extractFilename(path: string): string {
+ let filename: string
if (path.startsWith('/api/files/serve/')) {
- return path.substring('/api/files/serve/'.length)
+ filename = path.substring('/api/files/serve/'.length)
+ } else {
+ filename = path.split('/').pop() || path
}
- return path.split('/').pop() || path
+ filename = filename
+ .replace(/\.\./g, '')
+ .replace(/\/\.\./g, '')
+ .replace(/\.\.\//g, '')
// Handle cloud storage paths (s3/key, blob/key) - preserve forward slashes for these
if (filename.startsWith('s3/') || filename.startsWith('blob/')) {
// For cloud paths, only sanitize the key portion after the prefix
const parts = filename.split('/')
const prefix = parts[0] // 's3' or 'blob'
const keyParts = parts.slice(1)
// Sanitize each part of the key to prevent traversal
const sanitizedKeyParts = keyParts
.map((part) => part.replace(/\.\./g, '').replace(/^\./g, '').trim())
.filter((part) => part.length > 0)
filename = `${prefix}/${sanitizedKeyParts.join('/')}`
} else {
// For regular filenames, remove any remaining path separators
filename = filename.replace(/[/\\]/g, '')
}
// Additional validation: ensure filename is not empty after sanitization
if (!filename || filename.trim().length === 0) {
throw new Error('Invalid or empty filename after sanitization')
}
return filename
}
/**
@@ -174,16 +206,65 @@ export function findLocalFile(filename: string): string | null {
return null
}
const SAFE_INLINE_TYPES = new Set([
'image/png',
'image/jpeg',
'image/jpg',
'image/gif',
'application/pdf',
'text/plain',
'text/csv',
'application/json',
])
// File extensions that should always be served as attachment for security
const FORCE_ATTACHMENT_EXTENSIONS = new Set(['html', 'htm', 'svg', 'js', 'css', 'xml'])
/**
- * Create a file response with appropriate headers
+ * Determines safe content type and disposition for file serving
*/
function getSecureFileHeaders(filename: string, originalContentType: string) {
const extension = filename.split('.').pop()?.toLowerCase() || ''
// Force attachment for potentially dangerous file types
if (FORCE_ATTACHMENT_EXTENSIONS.has(extension)) {
return {
contentType: 'application/octet-stream', // Force download
disposition: 'attachment',
}
}
// Override content type for safety while preserving legitimate use cases
let safeContentType = originalContentType
// Handle potentially dangerous content types
if (originalContentType === 'text/html' || originalContentType === 'image/svg+xml') {
safeContentType = 'text/plain' // Prevent browser rendering
}
// Use inline only for verified safe content types
const disposition = SAFE_INLINE_TYPES.has(safeContentType) ? 'inline' : 'attachment'
return {
contentType: safeContentType,
disposition,
}
}
/**
* Create a file response with appropriate security headers
*/
export function createFileResponse(file: FileResponse): NextResponse {
+ const { contentType, disposition } = getSecureFileHeaders(file.filename, file.contentType)
return new NextResponse(file.buffer as BodyInit, {
status: 200,
headers: {
- 'Content-Type': file.contentType,
- 'Content-Disposition': `inline; filename="${file.filename}"`,
+ 'Content-Type': contentType,
+ 'Content-Disposition': `${disposition}; filename="${file.filename}"`,
'Cache-Control': 'public, max-age=31536000', // Cache for 1 year
'X-Content-Type-Options': 'nosniff',
'Content-Security-Policy': "default-src 'none'; style-src 'unsafe-inline'; sandbox;",
},
})
}
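The net effect of getSecureFileHeaders can be read off as a small decision table; each of the following corresponds to a test above:
getSecureFileHeaders('document.pdf', 'application/pdf')
// -> { contentType: 'application/pdf', disposition: 'inline' }
getSecureFileHeaders('malicious.html', 'text/html')
// -> { contentType: 'application/octet-stream', disposition: 'attachment' } (extension is force-attachment)
getSecureFileHeaders('image.png', 'image/svg+xml')
// -> { contentType: 'text/plain', disposition: 'inline' } (SVG content type downgraded; text/plain is inline-safe)
getSecureFileHeaders('unknown.bin', 'application/unknown')
// -> { contentType: 'application/unknown', disposition: 'attachment' } (not in SAFE_INLINE_TYPES)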

View File

@@ -32,6 +32,14 @@ describe('Function Execute API Route', () => {
createLogger: vi.fn().mockReturnValue(mockLogger),
}))
vi.doMock('@/lib/execution/e2b', () => ({
executeInE2B: vi.fn().mockResolvedValue({
result: 'e2b success',
stdout: 'e2b output',
sandboxId: 'test-sandbox-id',
}),
}))
mockRunInContext.mockResolvedValue('vm success')
mockCreateContext.mockReturnValue({})
})
@@ -45,6 +53,7 @@ describe('Function Execute API Route', () => {
const req = createMockRequest('POST', {
code: 'return "Hello World"',
timeout: 5000,
useLocalVM: true,
})
const { POST } = await import('@/app/api/function/execute/route')
@@ -74,6 +83,7 @@ describe('Function Execute API Route', () => {
it('should use default timeout when not provided', async () => {
const req = createMockRequest('POST', {
code: 'return "test"',
useLocalVM: true,
})
const { POST } = await import('@/app/api/function/execute/route')
@@ -93,6 +103,7 @@ describe('Function Execute API Route', () => {
it('should resolve environment variables with {{var_name}} syntax', async () => {
const req = createMockRequest('POST', {
code: 'return {{API_KEY}}',
useLocalVM: true,
envVars: {
API_KEY: 'secret-key-123',
},
@@ -108,6 +119,7 @@ describe('Function Execute API Route', () => {
it('should resolve tag variables with <tag_name> syntax', async () => {
const req = createMockRequest('POST', {
code: 'return <email>',
useLocalVM: true,
params: {
email: { id: '123', subject: 'Test Email' },
},
@@ -123,6 +135,7 @@ describe('Function Execute API Route', () => {
it('should NOT treat email addresses as template variables', async () => {
const req = createMockRequest('POST', {
code: 'return "Email sent to user"',
useLocalVM: true,
params: {
email: {
from: 'Waleed Latif <waleed@sim.ai>',
@@ -141,6 +154,7 @@ describe('Function Execute API Route', () => {
it('should only match valid variable names in angle brackets', async () => {
const req = createMockRequest('POST', {
code: 'return <validVar> + "<invalid@email.com>" + <another_valid>',
useLocalVM: true,
params: {
validVar: 'hello',
another_valid: 'world',
@@ -178,6 +192,7 @@ describe('Function Execute API Route', () => {
const req = createMockRequest('POST', {
code: 'return <email>',
useLocalVM: true,
params: gmailData,
})
@@ -200,6 +215,7 @@ describe('Function Execute API Route', () => {
const req = createMockRequest('POST', {
code: 'return <email>',
useLocalVM: true,
params: complexEmailData,
})
@@ -214,6 +230,7 @@ describe('Function Execute API Route', () => {
it('should handle custom tool execution with direct parameter access', async () => {
const req = createMockRequest('POST', {
code: 'return location + " weather is sunny"',
useLocalVM: true,
params: {
location: 'San Francisco',
},
@@ -245,6 +262,7 @@ describe('Function Execute API Route', () => {
it('should handle timeout parameter', async () => {
const req = createMockRequest('POST', {
code: 'return "test"',
useLocalVM: true,
timeout: 10000,
})
@@ -262,6 +280,7 @@ describe('Function Execute API Route', () => {
it('should handle empty parameters object', async () => {
const req = createMockRequest('POST', {
code: 'return "no params"',
useLocalVM: true,
params: {},
})
@@ -295,6 +314,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'const obj = {\n name: "test",\n description: "This has a missing closing quote\n};\nreturn obj;',
useLocalVM: true,
timeout: 5000,
})
@@ -338,6 +358,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'const obj = null;\nreturn obj.someMethod();',
useLocalVM: true,
timeout: 5000,
})
@@ -379,6 +400,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'const x = 42;\nreturn undefinedVariable + x;',
useLocalVM: true,
timeout: 5000,
})
@@ -409,6 +431,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'return "test";',
useLocalVM: true,
timeout: 5000,
})
@@ -445,6 +468,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'const a = 1;\nconst b = 2;\nconst c = 3;\nconst d = 4;\nreturn a + b + c + d;',
useLocalVM: true,
timeout: 5000,
})
@@ -476,6 +500,7 @@ SyntaxError: Invalid or unexpected token
const req = createMockRequest('POST', {
code: 'const obj = {\n name: "test"\n// Missing closing brace',
useLocalVM: true,
timeout: 5000,
})
@@ -496,6 +521,7 @@ SyntaxError: Invalid or unexpected token
// This tests the escapeRegExp function indirectly
const req = createMockRequest('POST', {
code: 'return {{special.chars+*?}}',
useLocalVM: true,
envVars: {
'special.chars+*?': 'escaped-value',
},
@@ -512,6 +538,7 @@ SyntaxError: Invalid or unexpected token
// Test with complex but not circular data first
const req = createMockRequest('POST', {
code: 'return <complexData>',
useLocalVM: true,
params: {
complexData: {
special: 'chars"with\'quotes',

View File

@@ -1,5 +1,8 @@
import { createContext, Script } from 'vm'
import { type NextRequest, NextResponse } from 'next/server'
import { env, isTruthy } from '@/lib/env'
import { executeInE2B } from '@/lib/execution/e2b'
import { CodeLanguage, DEFAULT_CODE_LANGUAGE, isValidCodeLanguage } from '@/lib/execution/languages'
import { createLogger } from '@/lib/logs/console/logger'
export const dynamic = 'force-dynamic'
@@ -8,6 +11,10 @@ export const maxDuration = 60
const logger = createLogger('FunctionExecuteAPI')
// Constants for E2B code wrapping line counts
const E2B_JS_WRAPPER_LINES = 3 // Lines before user code: ';(async () => {', ' try {', ' const __sim_result = await (async () => {'
const E2B_PYTHON_WRAPPER_LINES = 1 // Lines before user code: 'def __sim_main__():'
/**
* Enhanced error information interface
*/
@@ -124,6 +131,103 @@ function extractEnhancedError(
return enhanced
}
/**
* Parse and format E2B error message
* Removes E2B-specific line references and adds correct user line numbers
*/
function formatE2BError(
errorMessage: string,
errorOutput: string,
language: CodeLanguage,
userCode: string,
prologueLineCount: number
): { formattedError: string; cleanedOutput: string } {
// Calculate line offset based on language and prologue
const wrapperLines =
language === CodeLanguage.Python ? E2B_PYTHON_WRAPPER_LINES : E2B_JS_WRAPPER_LINES
const totalOffset = prologueLineCount + wrapperLines
let userLine: number | undefined
let cleanErrorType = ''
let cleanErrorMsg = ''
if (language === CodeLanguage.Python) {
// Python error format: "Cell In[X], line Y" followed by error details
// Extract line number from the Cell reference
const cellMatch = errorOutput.match(/Cell In\[\d+\], line (\d+)/)
if (cellMatch) {
const originalLine = Number.parseInt(cellMatch[1], 10)
userLine = originalLine - totalOffset
}
// Extract clean error message from the error string
// Remove file references like "(detected at line X) (file.py, line Y)"
cleanErrorMsg = errorMessage
.replace(/\s*\(detected at line \d+\)/g, '')
.replace(/\s*\([^)]+\.py, line \d+\)/g, '')
.trim()
} else if (language === CodeLanguage.JavaScript) {
// JavaScript error format from E2B: "SyntaxError: /path/file.ts: Message. (line:col)\n\n 9 | ..."
// First, extract the error type and message from the first line
const firstLineEnd = errorMessage.indexOf('\n')
const firstLine = firstLineEnd > 0 ? errorMessage.substring(0, firstLineEnd) : errorMessage
// Parse: "SyntaxError: /home/user/index.ts: Missing semicolon. (11:9)"
const jsErrorMatch = firstLine.match(/^(\w+Error):\s*[^:]+:\s*([^(]+)\.\s*\((\d+):(\d+)\)/)
if (jsErrorMatch) {
cleanErrorType = jsErrorMatch[1]
cleanErrorMsg = jsErrorMatch[2].trim()
const originalLine = Number.parseInt(jsErrorMatch[3], 10)
userLine = originalLine - totalOffset
} else {
// Fallback: look for line number in the arrow pointer line (> 11 |)
const arrowMatch = errorMessage.match(/^>\s*(\d+)\s*\|/m)
if (arrowMatch) {
const originalLine = Number.parseInt(arrowMatch[1], 10)
userLine = originalLine - totalOffset
}
// Try to extract error type and message
const errorMatch = firstLine.match(/^(\w+Error):\s*(.+)/)
if (errorMatch) {
cleanErrorType = errorMatch[1]
cleanErrorMsg = errorMatch[2]
.replace(/^[^:]+:\s*/, '') // Remove file path
.replace(/\s*\(\d+:\d+\)\s*$/, '') // Remove line:col at end
.trim()
} else {
cleanErrorMsg = firstLine
}
}
}
// Build the final clean error message
const finalErrorMsg =
cleanErrorType && cleanErrorMsg
? `${cleanErrorType}: ${cleanErrorMsg}`
: cleanErrorMsg || errorMessage
// Format with line number if available
let formattedError = finalErrorMsg
if (userLine && userLine > 0) {
const codeLines = userCode.split('\n')
// Clamp userLine to the actual user code range
const actualUserLine = Math.min(userLine, codeLines.length)
if (actualUserLine > 0 && actualUserLine <= codeLines.length) {
const lineContent = codeLines[actualUserLine - 1]?.trim()
if (lineContent) {
formattedError = `Line ${actualUserLine}: \`${lineContent}\` - ${finalErrorMsg}`
} else {
formattedError = `Line ${actualUserLine} - ${finalErrorMsg}`
}
}
}
// For stdout, just return the clean error message without the full traceback
const cleanedOutput = finalErrorMsg
return { formattedError, cleanedOutput }
}
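A quick worked example of the offset arithmetic, using the route's Python prologue shown further down (import json plus the params and environmentVariables assignments gives prologueLineCount = 3 when there are no context variables):
const prologueLineCount = 3 // import json, params, environmentVariables
const totalOffset = prologueLineCount + E2B_PYTHON_WRAPPER_LINES // 3 + 1 = 4
// An E2B traceback line "Cell In[1], line 6" therefore maps to
const userLine = 6 - totalOffset // line 2 of the user's original code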
/**
* Create a detailed error message for users
*/
@@ -442,6 +546,8 @@ export async function POST(req: NextRequest) {
code,
params = {},
timeout = 5000,
language = DEFAULT_CODE_LANGUAGE,
useLocalVM = false,
envVars = {},
blockData = {},
blockNameMapping = {},
@@ -474,19 +580,164 @@ export async function POST(req: NextRequest) {
resolvedCode = codeResolution.resolvedCode
const contextVariables = codeResolution.contextVariables
- const executionMethod = 'vm' // Default execution method
+ const e2bEnabled = isTruthy(env.E2B_ENABLED)
+ const lang = isValidCodeLanguage(language) ? language : DEFAULT_CODE_LANGUAGE
+ const useE2B =
+ e2bEnabled &&
+ !useLocalVM &&
+ !isCustomTool &&
+ (lang === CodeLanguage.JavaScript || lang === CodeLanguage.Python)
logger.info(`[${requestId}] Using VM for code execution`, {
hasEnvVars: Object.keys(envVars).length > 0,
hasWorkflowVariables: Object.keys(workflowVariables).length > 0,
})
if (useE2B) {
logger.info(`[${requestId}] E2B status`, {
enabled: e2bEnabled,
hasApiKey: Boolean(process.env.E2B_API_KEY),
language: lang,
})
let prologue = ''
const epilogue = ''
- // Create a secure context with console logging
if (lang === CodeLanguage.JavaScript) {
// Track prologue lines for error adjustment
let prologueLineCount = 0
prologue += `const params = JSON.parse(${JSON.stringify(JSON.stringify(executionParams))});\n`
prologueLineCount++
prologue += `const environmentVariables = JSON.parse(${JSON.stringify(JSON.stringify(envVars))});\n`
prologueLineCount++
for (const [k, v] of Object.entries(contextVariables)) {
prologue += `const ${k} = JSON.parse(${JSON.stringify(JSON.stringify(v))});\n`
prologueLineCount++
}
const wrapped = [
';(async () => {',
' try {',
' const __sim_result = await (async () => {',
` ${resolvedCode.split('\n').join('\n ')}`,
' })();',
" console.log('__SIM_RESULT__=' + JSON.stringify(__sim_result));",
' } catch (error) {',
' console.log(String((error && (error.stack || error.message)) || error));',
' throw error;',
' }',
'})();',
].join('\n')
const codeForE2B = prologue + wrapped + epilogue
const execStart = Date.now()
const {
result: e2bResult,
stdout: e2bStdout,
sandboxId,
error: e2bError,
} = await executeInE2B({
code: codeForE2B,
language: CodeLanguage.JavaScript,
timeoutMs: timeout,
})
const executionTime = Date.now() - execStart
stdout += e2bStdout
logger.info(`[${requestId}] E2B JS sandbox`, {
sandboxId,
stdoutPreview: e2bStdout?.slice(0, 200),
error: e2bError,
})
// If there was an execution error, format it properly
if (e2bError) {
const { formattedError, cleanedOutput } = formatE2BError(
e2bError,
e2bStdout,
lang,
resolvedCode,
prologueLineCount
)
return NextResponse.json(
{
success: false,
error: formattedError,
output: { result: null, stdout: cleanedOutput, executionTime },
},
{ status: 500 }
)
}
return NextResponse.json({
success: true,
output: { result: e2bResult ?? null, stdout, executionTime },
})
}
// Track prologue lines for error adjustment
let prologueLineCount = 0
prologue += 'import json\n'
prologueLineCount++
prologue += `params = json.loads(${JSON.stringify(JSON.stringify(executionParams))})\n`
prologueLineCount++
prologue += `environmentVariables = json.loads(${JSON.stringify(JSON.stringify(envVars))})\n`
prologueLineCount++
for (const [k, v] of Object.entries(contextVariables)) {
prologue += `${k} = json.loads(${JSON.stringify(JSON.stringify(v))})\n`
prologueLineCount++
}
const wrapped = [
'def __sim_main__():',
...resolvedCode.split('\n').map((l) => ` ${l}`),
'__sim_result__ = __sim_main__()',
"print('__SIM_RESULT__=' + json.dumps(__sim_result__))",
].join('\n')
const codeForE2B = prologue + wrapped + epilogue
const execStart = Date.now()
const {
result: e2bResult,
stdout: e2bStdout,
sandboxId,
error: e2bError,
} = await executeInE2B({
code: codeForE2B,
language: CodeLanguage.Python,
timeoutMs: timeout,
})
const executionTime = Date.now() - execStart
stdout += e2bStdout
logger.info(`[${requestId}] E2B Py sandbox`, {
sandboxId,
stdoutPreview: e2bStdout?.slice(0, 200),
error: e2bError,
})
// If there was an execution error, format it properly
if (e2bError) {
const { formattedError, cleanedOutput } = formatE2BError(
e2bError,
e2bStdout,
lang,
resolvedCode,
prologueLineCount
)
return NextResponse.json(
{
success: false,
error: formattedError,
output: { result: null, stdout: cleanedOutput, executionTime },
},
{ status: 500 }
)
}
return NextResponse.json({
success: true,
output: { result: e2bResult ?? null, stdout, executionTime },
})
}
+ const executionMethod = 'vm'
const context = createContext({
params: executionParams,
environmentVariables: envVars,
- ...contextVariables, // Add resolved variables directly to context
- fetch: globalThis.fetch || require('node-fetch').default,
+ ...contextVariables,
+ fetch: (globalThis as any).fetch || require('node-fetch').default,
console: {
log: (...args: any[]) => {
const logMessage = `${args
@@ -504,23 +755,17 @@ export async function POST(req: NextRequest) {
},
})
// Calculate line offset for user code to provide accurate error reporting
const wrapperLines = ['(async () => {', ' try {']
// Add custom tool parameter declarations if needed
if (isCustomTool) {
wrapperLines.push(' // For custom tools, make parameters directly accessible')
Object.keys(executionParams).forEach((key) => {
wrapperLines.push(` const ${key} = params.${key};`)
})
}
- userCodeStartLine = wrapperLines.length + 1 // +1 because user code starts on next line
- // Build the complete script with proper formatting for line numbers
+ userCodeStartLine = wrapperLines.length + 1
const fullScript = [
...wrapperLines,
- ` ${resolvedCode.split('\n').join('\n ')}`, // Indent user code
+ ` ${resolvedCode.split('\n').join('\n ')}`,
' } catch (error) {',
' console.error(error);',
' throw error;',
@@ -529,33 +774,26 @@ export async function POST(req: NextRequest) {
].join('\n')
const script = new Script(fullScript, {
- filename: 'user-function.js', // This filename will appear in stack traces
- lineOffset: 0, // Start line numbering from 0
- columnOffset: 0, // Start column numbering from 0
+ filename: 'user-function.js',
+ lineOffset: 0,
+ columnOffset: 0,
})
const result = await script.runInContext(context, {
timeout,
displayErrors: true,
- breakOnSigint: true, // Allow breaking on SIGINT for better debugging
+ breakOnSigint: true,
})
- // }
const executionTime = Date.now() - startTime
logger.info(`[${requestId}] Function executed successfully using ${executionMethod}`, {
executionTime,
})
- const response = {
+ return NextResponse.json({
success: true,
- output: {
- result,
- stdout,
- executionTime,
- },
- }
- return NextResponse.json(response)
+ output: { result, stdout, executionTime },
+ })
} catch (error: any) {
const executionTime = Date.now() - startTime
logger.error(`[${requestId}] Function execution failed`, {

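To make the E2B wrapping concrete: for the one-line user script return 1 + 1 with no params, the JavaScript branch above assembles roughly the following program (the double JSON.stringify embeds each value as a safely escaped string literal, which renders here as JSON.parse("{}")):
const params = JSON.parse("{}");
const environmentVariables = JSON.parse("{}");
;(async () => {
  try {
    const __sim_result = await (async () => {
      return 1 + 1
    })();
    console.log('__SIM_RESULT__=' + JSON.stringify(__sim_result));
  } catch (error) {
    console.log(String((error && (error.stack || error.message)) || error));
    throw error;
  }
})();
The sandbox then prints __SIM_RESULT__=2, and executeInE2B presumably recovers the result by scanning stdout for that marker (an assumption here; its implementation is not part of this diff).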
View File

@@ -1,12 +1,10 @@
- import { createHash, randomUUID } from 'crypto'
- import { eq, sql } from 'drizzle-orm'
+ import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
+ import { deleteChunk, updateChunk } from '@/lib/knowledge/chunks/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkChunkAccess } from '@/app/api/knowledge/utils'
- import { db } from '@/db'
- import { document, embedding } from '@/db/schema'
const logger = createLogger('ChunkByIdAPI')
@@ -102,33 +100,7 @@ export async function PUT(
try {
const validatedData = UpdateChunkSchema.parse(body)
- const updateData: Partial<{
- content: string
- contentLength: number
- tokenCount: number
- chunkHash: string
- enabled: boolean
- updatedAt: Date
- }> = {}
- if (validatedData.content) {
- updateData.content = validatedData.content
- updateData.contentLength = validatedData.content.length
- // Update token count estimation (rough approximation: 4 chars per token)
- updateData.tokenCount = Math.ceil(validatedData.content.length / 4)
- updateData.chunkHash = createHash('sha256').update(validatedData.content).digest('hex')
- }
- if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
- await db.update(embedding).set(updateData).where(eq(embedding.id, chunkId))
- // Fetch the updated chunk
- const updatedChunk = await db
- .select()
- .from(embedding)
- .where(eq(embedding.id, chunkId))
- .limit(1)
+ const updatedChunk = await updateChunk(chunkId, validatedData, requestId)
logger.info(
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`
@@ -136,7 +108,7 @@ export async function PUT(
return NextResponse.json({
success: true,
- data: updatedChunk[0],
+ data: updatedChunk,
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
@@ -190,37 +162,7 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Use transaction to atomically delete chunk and update document statistics
await db.transaction(async (tx) => {
// Get chunk data before deletion for statistics update
const chunkToDelete = await tx
.select({
tokenCount: embedding.tokenCount,
contentLength: embedding.contentLength,
})
.from(embedding)
.where(eq(embedding.id, chunkId))
.limit(1)
if (chunkToDelete.length === 0) {
throw new Error('Chunk not found')
}
const chunk = chunkToDelete[0]
// Delete the chunk
await tx.delete(embedding).where(eq(embedding.id, chunkId))
// Update document statistics
await tx
.update(document)
.set({
chunkCount: sql`${document.chunkCount} - 1`,
tokenCount: sql`${document.tokenCount} - ${chunk.tokenCount}`,
characterCount: sql`${document.characterCount} - ${chunk.contentLength}`,
})
.where(eq(document.id, documentId))
})
+ await deleteChunk(chunkId, documentId, requestId)
logger.info(
`[${requestId}] Chunk deleted: ${chunkId} from document ${documentId} in knowledge base ${knowledgeBaseId}`

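The call sites above imply roughly these service signatures; the definitions live in @/lib/knowledge/chunks/service and are not shown in this diff, so treat the return types as assumptions:
declare function updateChunk(
  chunkId: string,
  data: { content?: string; enabled?: boolean },
  requestId: string
): Promise<unknown> // resolves to the updated chunk row, returned directly as the response data
declare function deleteChunk(
  chunkId: string,
  documentId: string,
  requestId: string
): Promise<void> // presumably also updates parent document statistics, matching the removed inline transaction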
View File

@@ -1,378 +0,0 @@
/**
* Tests for knowledge document chunks API route
*
* @vitest-environment node
*/
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockRequest,
mockAuth,
mockConsoleLogger,
mockDrizzleOrm,
mockKnowledgeSchemas,
} from '@/app/api/__test-utils__/utils'
mockKnowledgeSchemas()
mockDrizzleOrm()
mockConsoleLogger()
vi.mock('@/lib/tokenization/estimators', () => ({
estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }),
}))
vi.mock('@/providers/utils', () => ({
calculateCost: vi.fn().mockReturnValue({
input: 0.00000904,
output: 0,
total: 0.00000904,
pricing: {
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
},
}),
}))
vi.mock('@/app/api/knowledge/utils', () => ({
checkKnowledgeBaseAccess: vi.fn(),
checkKnowledgeBaseWriteAccess: vi.fn(),
checkDocumentAccess: vi.fn(),
checkDocumentWriteAccess: vi.fn(),
checkChunkAccess: vi.fn(),
generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]),
processDocumentAsync: vi.fn(),
}))
describe('Knowledge Document Chunks API Route', () => {
const mockAuth$ = mockAuth()
const mockDbChain = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockReturnThis(),
offset: vi.fn().mockReturnThis(),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
returning: vi.fn().mockResolvedValue([]),
delete: vi.fn().mockReturnThis(),
transaction: vi.fn(),
}
const mockGetUserId = vi.fn()
beforeEach(async () => {
vi.clearAllMocks()
vi.doMock('@/db', () => ({
db: mockDbChain,
}))
vi.doMock('@/app/api/auth/oauth/utils', () => ({
getUserId: mockGetUserId,
}))
Object.values(mockDbChain).forEach((fn) => {
if (typeof fn === 'function' && fn !== mockDbChain.values && fn !== mockDbChain.returning) {
fn.mockClear().mockReturnThis()
}
})
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'),
createHash: vi.fn().mockReturnValue({
update: vi.fn().mockReturnThis(),
digest: vi.fn().mockReturnValue('mock-hash-123'),
}),
})
})
afterEach(() => {
vi.clearAllMocks()
})
describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => {
const validChunkData = {
content: 'This is test chunk content for uploading to the knowledge base document.',
enabled: true,
}
const mockDocumentAccess = {
hasAccess: true,
notFound: false,
reason: '',
document: {
id: 'doc-123',
processingStatus: 'completed',
tag1: 'tag1-value',
tag2: 'tag2-value',
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
},
}
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
it('should create chunk successfully with cost tracking', async () => {
const { checkDocumentWriteAccess, generateEmbeddings } = await import(
'@/app/api/knowledge/utils'
)
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
const { calculateCost } = await import('@/providers/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
// Mock generateEmbeddings
vi.mocked(generateEmbeddings).mockResolvedValue([[0.1, 0.2, 0.3]])
// Mock transaction
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(200)
expect(data.success).toBe(true)
// Verify cost tracking
expect(data.data.cost).toBeDefined()
expect(data.data.cost.input).toBe(0.00000904)
expect(data.data.cost.output).toBe(0)
expect(data.data.cost.total).toBe(0.00000904)
expect(data.data.cost.tokens).toEqual({
prompt: 452,
completion: 0,
total: 452,
})
expect(data.data.cost.model).toBe('text-embedding-3-small')
expect(data.data.cost.pricing).toEqual({
input: 0.02,
output: 0,
updatedAt: '2025-07-10',
})
// Verify function calls
expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai')
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false)
})
it('should handle workflow-based authentication', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const workflowData = {
...validChunkData,
workflowId: 'workflow-123',
}
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockResolvedValue(undefined),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', workflowData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
})
it.concurrent('should return unauthorized for unauthenticated request', async () => {
mockGetUserId.mockResolvedValue(null)
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(401)
expect(data.error).toBe('Unauthorized')
})
it('should return not found for workflow that does not exist', async () => {
const workflowData = {
...validChunkData,
workflowId: 'nonexistent-workflow',
}
mockGetUserId.mockResolvedValue(null)
const req = createMockRequest('POST', workflowData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Workflow not found')
})
it.concurrent('should return not found for document access denied', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
hasAccess: false,
notFound: true,
reason: 'Document not found',
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Document not found')
})
it('should return unauthorized for unauthorized document access', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
hasAccess: false,
notFound: false,
reason: 'Unauthorized access',
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(401)
expect(data.error).toBe('Unauthorized')
})
it('should reject chunks for failed documents', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
document: {
...mockDocumentAccess.document!,
processingStatus: 'failed',
},
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(400)
expect(data.error).toBe('Cannot add chunks to failed document')
})
it.concurrent('should validate chunk data', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const invalidData = {
content: '', // Empty content
enabled: true,
}
const req = createMockRequest('POST', invalidData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(400)
expect(data.error).toBe('Invalid request data')
expect(data.details).toBeDefined()
})
it('should inherit tags from parent document', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
mockGetUserId.mockResolvedValue('user-123')
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
...mockDocumentAccess,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
} as any)
const mockTx = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
where: vi.fn().mockReturnThis(),
orderBy: vi.fn().mockReturnThis(),
limit: vi.fn().mockResolvedValue([]),
insert: vi.fn().mockReturnThis(),
values: vi.fn().mockImplementation((data) => {
// Verify that tags are inherited from document
expect(data.tag1).toBe('tag1-value')
expect(data.tag2).toBe('tag2-value')
expect(data.tag3).toBe(null)
return Promise.resolve(undefined)
}),
update: vi.fn().mockReturnThis(),
set: vi.fn().mockReturnThis(),
}
mockDbChain.transaction.mockImplementation(async (callback) => {
return await callback(mockTx)
})
const req = createMockRequest('POST', validChunkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
await POST(req, { params: mockParams })
expect(mockTx.values).toHaveBeenCalled()
})
// REMOVED: "should handle cost calculation with different content lengths" test - it was failing
})
})

View File

@@ -1,18 +1,11 @@
- import crypto from 'crypto'
- import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
+ import { batchChunkOperation, createChunk, queryChunks } from '@/lib/knowledge/chunks/service'
import { createLogger } from '@/lib/logs/console/logger'
- import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
- import {
- checkDocumentAccess,
- checkDocumentWriteAccess,
- generateEmbeddings,
- } from '@/app/api/knowledge/utils'
- import { db } from '@/db'
- import { document, embedding } from '@/db/schema'
+ import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
import { calculateCost } from '@/providers/utils'
const logger = createLogger('DocumentChunksAPI')
@@ -66,7 +59,6 @@ export async function GET(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
- // Check if document processing is completed
const doc = accessCheck.document
if (!doc) {
logger.warn(
@@ -89,7 +81,6 @@ export async function GET(
)
}
- // Parse query parameters
const { searchParams } = new URL(req.url)
const queryParams = GetChunksQuerySchema.parse({
search: searchParams.get('search') || undefined,
@@ -98,67 +89,12 @@ export async function GET(
offset: searchParams.get('offset') || undefined,
})
// Build query conditions
const conditions = [eq(embedding.documentId, documentId)]
// Add enabled filter
if (queryParams.enabled === 'true') {
conditions.push(eq(embedding.enabled, true))
} else if (queryParams.enabled === 'false') {
conditions.push(eq(embedding.enabled, false))
}
// Add search filter
if (queryParams.search) {
conditions.push(ilike(embedding.content, `%${queryParams.search}%`))
}
// Fetch chunks
const chunks = await db
.select({
id: embedding.id,
chunkIndex: embedding.chunkIndex,
content: embedding.content,
contentLength: embedding.contentLength,
tokenCount: embedding.tokenCount,
enabled: embedding.enabled,
startOffset: embedding.startOffset,
endOffset: embedding.endOffset,
tag1: embedding.tag1,
tag2: embedding.tag2,
tag3: embedding.tag3,
tag4: embedding.tag4,
tag5: embedding.tag5,
tag6: embedding.tag6,
tag7: embedding.tag7,
createdAt: embedding.createdAt,
updatedAt: embedding.updatedAt,
})
.from(embedding)
.where(and(...conditions))
.orderBy(asc(embedding.chunkIndex))
.limit(queryParams.limit)
.offset(queryParams.offset)
// Get total count for pagination
const totalCount = await db
.select({ count: sql`count(*)` })
.from(embedding)
.where(and(...conditions))
logger.info(
`[${requestId}] Retrieved ${chunks.length} chunks for document ${documentId} in knowledge base ${knowledgeBaseId}`
)
+ const result = await queryChunks(documentId, queryParams, requestId)
return NextResponse.json({
success: true,
- data: chunks,
- pagination: {
- total: Number(totalCount[0]?.count || 0),
- limit: queryParams.limit,
- offset: queryParams.offset,
- hasMore: chunks.length === queryParams.limit,
- },
+ data: result.chunks,
+ pagination: result.pagination,
})
} catch (error) {
logger.error(`[${requestId}] Error fetching chunks`, error)
@@ -219,76 +155,27 @@ export async function POST(
try {
const validatedData = CreateChunkSchema.parse(searchParams)
- // Generate embedding for the content first (outside transaction for performance)
- logger.info(`[${requestId}] Generating embedding for manual chunk`)
- const embeddings = await generateEmbeddings([validatedData.content])
+ const docTags = {
+ tag1: doc.tag1 ?? null,
+ tag2: doc.tag2 ?? null,
+ tag3: doc.tag3 ?? null,
+ tag4: doc.tag4 ?? null,
+ tag5: doc.tag5 ?? null,
+ tag6: doc.tag6 ?? null,
+ tag7: doc.tag7 ?? null,
+ }
- // Calculate accurate token count for both database storage and cost calculation
- const tokenCount = estimateTokenCount(validatedData.content, 'openai')
+ const newChunk = await createChunk(
+ knowledgeBaseId,
+ documentId,
+ docTags,
+ validatedData,
+ requestId
+ )
- const chunkId = crypto.randomUUID()
- const now = new Date()
// Use transaction to atomically get next index and insert chunk
const newChunk = await db.transaction(async (tx) => {
// Get the next chunk index atomically within the transaction
const lastChunk = await tx
.select({ chunkIndex: embedding.chunkIndex })
.from(embedding)
.where(eq(embedding.documentId, documentId))
.orderBy(sql`${embedding.chunkIndex} DESC`)
.limit(1)
const nextChunkIndex = lastChunk.length > 0 ? lastChunk[0].chunkIndex + 1 : 0
const chunkData = {
id: chunkId,
knowledgeBaseId,
documentId,
chunkIndex: nextChunkIndex,
chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
content: validatedData.content,
contentLength: validatedData.content.length,
tokenCount: tokenCount.count, // Use accurate token count
embedding: embeddings[0],
embeddingModel: 'text-embedding-3-small',
startOffset: 0, // Manual chunks don't have document offsets
endOffset: validatedData.content.length,
// Inherit tags from parent document
tag1: doc.tag1,
tag2: doc.tag2,
tag3: doc.tag3,
tag4: doc.tag4,
tag5: doc.tag5,
tag6: doc.tag6,
tag7: doc.tag7,
enabled: validatedData.enabled,
createdAt: now,
updatedAt: now,
}
// Insert the new chunk
await tx.insert(embedding).values(chunkData)
// Update document statistics
await tx
.update(document)
.set({
chunkCount: sql`${document.chunkCount} + 1`,
tokenCount: sql`${document.tokenCount} + ${chunkData.tokenCount}`,
characterCount: sql`${document.characterCount} + ${chunkData.contentLength}`,
})
.where(eq(document.id, documentId))
return chunkData
})
logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
// Calculate cost for the embedding (with fallback if calculation fails)
let cost = null
try {
- cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
+ cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false)
} catch (error) {
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
error: error instanceof Error ? error.message : 'Unknown error',
@@ -300,6 +187,8 @@ export async function POST(
success: true,
data: {
...newChunk,
documentId,
documentName: doc.filename,
...(cost
? {
cost: {
@@ -307,9 +196,9 @@ export async function POST(
output: cost.output,
total: cost.total,
tokens: {
- prompt: tokenCount.count,
+ prompt: newChunk.tokenCount,
completion: 0,
- total: tokenCount.count,
+ total: newChunk.tokenCount,
},
model: 'text-embedding-3-small',
pricing: cost.pricing,
@@ -371,92 +260,16 @@ export async function PATCH(
const validatedData = BatchOperationSchema.parse(body)
const { operation, chunkIds } = validatedData
logger.info(
`[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}`
)
const results = []
let successCount = 0
const errorCount = 0
if (operation === 'delete') {
// Handle batch delete with transaction for consistency
await db.transaction(async (tx) => {
// Get chunks to delete for statistics update
const chunksToDelete = await tx
.select({
id: embedding.id,
tokenCount: embedding.tokenCount,
contentLength: embedding.contentLength,
})
.from(embedding)
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
if (chunksToDelete.length === 0) {
throw new Error('No valid chunks found to delete')
}
// Delete chunks
await tx
.delete(embedding)
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
// Update document statistics
const totalTokens = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0)
const totalCharacters = chunksToDelete.reduce(
(sum, chunk) => sum + chunk.contentLength,
0
)
await tx
.update(document)
.set({
chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`,
tokenCount: sql`${document.tokenCount} - ${totalTokens}`,
characterCount: sql`${document.characterCount} - ${totalCharacters}`,
})
.where(eq(document.id, documentId))
successCount = chunksToDelete.length
results.push({
operation: 'delete',
deletedCount: chunksToDelete.length,
chunkIds: chunksToDelete.map((c) => c.id),
})
})
} else {
// Handle batch enable/disable
const enabled = operation === 'enable'
// Update chunks in a single query
const updateResult = await db
.update(embedding)
.set({
enabled,
updatedAt: new Date(),
})
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
.returning({ id: embedding.id })
successCount = updateResult.length
results.push({
operation,
updatedCount: updateResult.length,
chunkIds: updateResult.map((r) => r.id),
})
}
logger.info(
`[${requestId}] Batch ${operation} operation completed: ${successCount} successful, ${errorCount} errors`
)
+ const result = await batchChunkOperation(documentId, operation, chunkIds, requestId)
return NextResponse.json({
success: true,
data: {
operation,
- successCount,
- errorCount,
- results,
+ successCount: result.processed,
+ errorCount: result.errors.length,
+ processed: result.processed,
+ errors: result.errors,
},
})
} catch (validationError) {

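From the way the route consumes the new service calls, the result shapes are roughly as follows (illustrative only; the real types ship with the service module, not this diff):
interface QueryChunksResult {
  chunks: unknown[] // becomes the response's data field
  pagination: { total: number; limit: number; offset: number; hasMore: boolean } // assumed to mirror the old inline pagination
}
interface BatchChunkOperationResult {
  processed: number // mapped to both successCount and processed in the response
  errors: unknown[] // errorCount is derived as errors.length
}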
View File

@@ -24,7 +24,14 @@ vi.mock('@/app/api/knowledge/utils', () => ({
processDocumentAsync: vi.fn(),
}))
// Setup common mocks
vi.mock('@/lib/knowledge/documents/service', () => ({
updateDocument: vi.fn(),
deleteDocument: vi.fn(),
markDocumentAsFailedTimeout: vi.fn(),
retryDocumentProcessing: vi.fn(),
processDocumentAsync: vi.fn(),
}))
mockDrizzleOrm()
mockConsoleLogger()
@@ -42,8 +49,6 @@ describe('Document By ID API Route', () => {
transaction: vi.fn(),
}
- // Mock functions will be imported dynamically in tests
const mockDocument = {
id: 'doc-123',
knowledgeBaseId: 'kb-123',
@@ -73,7 +78,6 @@ describe('Document By ID API Route', () => {
}
}
})
- // Mock functions are cleared automatically by vitest
}
beforeEach(async () => {
@@ -83,8 +87,6 @@ describe('Document By ID API Route', () => {
db: mockDbChain,
}))
- // Utils are mocked at the top level
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
})
@@ -195,6 +197,7 @@ describe('Document By ID API Route', () => {
it('should update document successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { updateDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -203,31 +206,12 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
- // Create a sequence of mocks for the database operations
- const updateChain = {
- set: vi.fn().mockReturnValue({
- where: vi.fn().mockResolvedValue(undefined), // Update operation completes
- }),
+ const updatedDocument = {
+ ...mockDocument,
+ ...validUpdateData,
+ deletedAt: null,
}
- const selectChain = {
- from: vi.fn().mockReturnValue({
- where: vi.fn().mockReturnValue({
- limit: vi.fn().mockResolvedValue([{ ...mockDocument, ...validUpdateData }]),
- }),
- }),
- }
- // Mock transaction
- mockDbChain.transaction.mockImplementation(async (callback) => {
- const mockTx = {
- update: vi.fn().mockReturnValue(updateChain),
- }
- await callback(mockTx)
- })
- // Mock db operations in sequence
- mockDbChain.select.mockReturnValue(selectChain)
+ vi.mocked(updateDocument).mockResolvedValue(updatedDocument)
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -238,8 +222,11 @@ describe('Document By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.filename).toBe('updated-document.pdf')
expect(data.data.enabled).toBe(false)
- expect(mockDbChain.transaction).toHaveBeenCalled()
- expect(mockDbChain.select).toHaveBeenCalled()
+ expect(vi.mocked(updateDocument)).toHaveBeenCalledWith(
+ 'doc-123',
+ validUpdateData,
+ expect.any(String)
+ )
})
it('should validate update data', async () => {
@@ -274,6 +261,7 @@ describe('Document By ID API Route', () => {
it('should mark document as failed due to timeout successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
const processingDocument = {
...mockDocument,
@@ -288,34 +276,11 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
- // Create a sequence of mocks for the database operations
- const updateChain = {
- set: vi.fn().mockReturnValue({
- where: vi.fn().mockResolvedValue(undefined), // Update operation completes
- }),
- }
- const selectChain = {
- from: vi.fn().mockReturnValue({
- where: vi.fn().mockReturnValue({
- limit: vi
- .fn()
- .mockResolvedValue([{ ...processingDocument, processingStatus: 'failed' }]),
- }),
- }),
- }
- // Mock transaction
- mockDbChain.transaction.mockImplementation(async (callback) => {
- const mockTx = {
- update: vi.fn().mockReturnValue(updateChain),
- }
- await callback(mockTx)
+ vi.mocked(markDocumentAsFailedTimeout).mockResolvedValue({
+ success: true,
+ processingDuration: 200000,
})
- // Mock db operations in sequence
- mockDbChain.select.mockReturnValue(selectChain)
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -323,13 +288,13 @@ describe('Document By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
- expect(mockDbChain.transaction).toHaveBeenCalled()
- expect(updateChain.set).toHaveBeenCalledWith(
- expect.objectContaining({
- processingStatus: 'failed',
- processingError: 'Processing timed out - background process may have been terminated',
- processingCompletedAt: expect.any(Date),
- })
+ expect(data.data.documentId).toBe('doc-123')
+ expect(data.data.status).toBe('failed')
+ expect(data.data.message).toBe('Document marked as failed due to timeout')
+ expect(vi.mocked(markDocumentAsFailedTimeout)).toHaveBeenCalledWith(
+ 'doc-123',
+ processingDocument.processingStartedAt,
+ expect.any(String)
+ )
})
@@ -354,6 +319,7 @@ describe('Document By ID API Route', () => {
it('should reject marking failed for recently started processing', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
const recentProcessingDocument = {
...mockDocument,
@@ -368,6 +334,10 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
vi.mocked(markDocumentAsFailedTimeout).mockRejectedValue(
new Error('Document has not been processing long enough to be considered dead')
)
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -382,9 +352,8 @@ describe('Document By ID API Route', () => {
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
it('should retry processing successfully', async () => {
const { checkDocumentWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { retryDocumentProcessing } = await import('@/lib/knowledge/documents/service')
const failedDocument = {
...mockDocument,
@@ -399,23 +368,12 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
delete: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined),
}),
update: vi.fn().mockReturnValue({
set: vi.fn().mockReturnValue({
where: vi.fn().mockResolvedValue(undefined),
}),
}),
}
return await callback(mockTx)
})
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
vi.mocked(retryDocumentProcessing).mockResolvedValue({
success: true,
status: 'pending',
message: 'Document retry processing started',
})
const req = createMockRequest('PUT', { retryProcessing: true })
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
const response = await PUT(req, { params: mockParams })
@@ -425,8 +383,17 @@ describe('Document By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.status).toBe('pending')
expect(data.data.message).toBe('Document retry processing started')
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(vi.mocked(processDocumentAsync)).toHaveBeenCalled()
expect(vi.mocked(retryDocumentProcessing)).toHaveBeenCalledWith(
'kb-123',
'doc-123',
{
filename: failedDocument.filename,
fileUrl: failedDocument.fileUrl,
fileSize: failedDocument.fileSize,
mimeType: failedDocument.mimeType,
},
expect.any(String)
)
})
it('should reject retry for non-failed document', async () => {
@@ -486,6 +453,7 @@ describe('Document By ID API Route', () => {
it('should handle database errors during update', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { updateDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -494,8 +462,7 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to throw an error
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))
vi.mocked(updateDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -512,6 +479,7 @@ describe('Document By ID API Route', () => {
it('should delete document successfully', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -520,10 +488,10 @@ describe('Document By ID API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Properly chain the mock database operations for soft delete
mockDbChain.update.mockReturnValue(mockDbChain)
mockDbChain.set.mockReturnValue(mockDbChain)
mockDbChain.where.mockResolvedValue(undefined) // Update operation resolves
vi.mocked(deleteDocument).mockResolvedValue({
success: true,
message: 'Document deleted successfully',
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
@@ -533,12 +501,7 @@ describe('Document By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.message).toBe('Document deleted successfully')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockDbChain.set).toHaveBeenCalledWith(
expect.objectContaining({
deletedAt: expect.any(Date),
})
)
expect(vi.mocked(deleteDocument)).toHaveBeenCalledWith('doc-123', expect.any(String))
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -592,6 +555,7 @@ describe('Document By ID API Route', () => {
it('should handle database errors during deletion', async () => {
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
@@ -599,7 +563,7 @@ describe('Document By ID API Route', () => {
document: mockDocument,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.set.mockRejectedValue(new Error('Database error'))
vi.mocked(deleteDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
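
The rewritten tests above follow one pattern throughout: the route handler is no longer exercised against a mocked drizzle chain but against a mocked service module, asserted through vi.mocked. A minimal, self-contained sketch of that vitest pattern, assuming only the service module path shown in the diff and the repo's path aliases:

import { describe, expect, it, vi } from 'vitest'

// Mock the service module before the route imports it, so the route
// resolves the mocked functions instead of real database calls.
vi.mock('@/lib/knowledge/documents/service', () => ({
  deleteDocument: vi.fn(),
}))

describe('document service mocking pattern', () => {
  it('resolves the mocked return value', async () => {
    const { deleteDocument } = await import('@/lib/knowledge/documents/service')
    vi.mocked(deleteDocument).mockResolvedValue({
      success: true,
      message: 'Document deleted successfully',
    })
    await expect(deleteDocument('doc-123', 'req-1')).resolves.toEqual({
      success: true,
      message: 'Document deleted successfully',
    })
  })
})

Mocking at the service seam keeps these tests stable when the underlying query shapes change, which is exactly what the old drizzle-chain mocks could not survive.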

View File

@@ -1,16 +1,14 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { TAG_SLOTS } from '@/lib/constants/knowledge'
import { createLogger } from '@/lib/logs/console/logger'
import {
checkDocumentAccess,
checkDocumentWriteAccess,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding } from '@/db/schema'
import {
deleteDocument,
markDocumentAsFailedTimeout,
retryDocumentProcessing,
updateDocument,
} from '@/lib/knowledge/documents/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
const logger = createLogger('DocumentByIdAPI')
@@ -113,9 +111,7 @@ export async function PUT(
const updateData: any = {}
// Handle special operations first
if (validatedData.markFailedDueToTimeout) {
// Mark document as failed due to timeout (replaces mark-failed endpoint)
const doc = accessCheck.document
if (doc.processingStatus !== 'processing') {
@@ -132,58 +128,30 @@ export async function PUT(
)
}
const now = new Date()
const processingDuration = now.getTime() - new Date(doc.processingStartedAt).getTime()
const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000
if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) {
return NextResponse.json(
{ error: 'Document has not been processing long enough to be considered dead' },
{ status: 400 }
)
}
updateData.processingStatus = 'failed'
updateData.processingError =
'Processing timed out - background process may have been terminated'
updateData.processingCompletedAt = now
logger.info(
`[${requestId}] Marked document ${documentId} as failed due to dead process (processing time: ${Math.round(processingDuration / 1000)}s)`
)
try {
await markDocumentAsFailedTimeout(documentId, doc.processingStartedAt, requestId)
return NextResponse.json({
success: true,
data: {
documentId,
status: 'failed',
message: 'Document marked as failed due to timeout',
},
})
} catch (error) {
if (error instanceof Error) {
return NextResponse.json({ error: error.message }, { status: 400 })
}
throw error
}
} else if (validatedData.retryProcessing) {
// Retry processing (replaces retry endpoint)
const doc = accessCheck.document
if (doc.processingStatus !== 'failed') {
return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 })
}
// Clear existing embeddings and reset document state
await db.transaction(async (tx) => {
await tx.delete(embedding).where(eq(embedding.documentId, documentId))
await tx
.update(document)
.set({
processingStatus: 'pending',
processingStartedAt: null,
processingCompletedAt: null,
processingError: null,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
})
.where(eq(document.id, documentId))
})
const processingOptions = {
chunkSize: 1024,
minCharactersPerChunk: 24,
recipe: 'default',
lang: 'en',
}
const docData = {
filename: doc.filename,
fileUrl: doc.fileUrl,
@@ -191,80 +159,33 @@ export async function PUT(
mimeType: doc.mimeType,
}
processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch(
(error: unknown) => {
logger.error(`[${requestId}] Background retry processing error:`, error)
}
)
const result = await retryDocumentProcessing(
knowledgeBaseId,
documentId,
docData,
requestId
)
logger.info(`[${requestId}] Document retry initiated: ${documentId}`)
return NextResponse.json({
success: true,
data: {
documentId,
status: 'pending',
message: 'Document retry processing started',
status: result.status,
message: result.message,
},
})
} else {
// Regular field updates
if (validatedData.filename !== undefined) updateData.filename = validatedData.filename
if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount
if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount
if (validatedData.characterCount !== undefined)
updateData.characterCount = validatedData.characterCount
if (validatedData.processingStatus !== undefined)
updateData.processingStatus = validatedData.processingStatus
if (validatedData.processingError !== undefined)
updateData.processingError = validatedData.processingError
const updatedDocument = await updateDocument(documentId, validatedData, requestId)
// Tag field updates
TAG_SLOTS.forEach((slot) => {
if ((validatedData as any)[slot] !== undefined) {
;(updateData as any)[slot] = (validatedData as any)[slot]
}
})
logger.info(
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: updatedDocument,
})
}
await db.transaction(async (tx) => {
// Update the document
await tx.update(document).set(updateData).where(eq(document.id, documentId))
// If any tag fields were updated, also update the embeddings
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)
if (hasTagUpdates) {
const embeddingUpdateData: Record<string, string | null> = {}
TAG_SLOTS.forEach((field) => {
if ((validatedData as any)[field] !== undefined) {
embeddingUpdateData[field] = (validatedData as any)[field] || null
}
})
await tx
.update(embedding)
.set(embeddingUpdateData)
.where(eq(embedding.documentId, documentId))
}
})
// Fetch the updated document
const updatedDocument = await db
.select()
.from(document)
.where(eq(document.id, documentId))
.limit(1)
logger.info(
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: updatedDocument[0],
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid document update data`, {
@@ -313,13 +234,7 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Soft delete by setting deletedAt timestamp
await db
.update(document)
.set({
deletedAt: new Date(),
})
.where(eq(document.id, documentId))
const result = await deleteDocument(documentId, requestId)
logger.info(
`[${requestId}] Document deleted: ${documentId} from knowledge base ${knowledgeBaseId}`
@@ -327,7 +242,7 @@ export async function DELETE(
return NextResponse.json({
success: true,
data: { message: 'Document deleted successfully' },
data: result,
})
} catch (error) {
logger.error(`[${requestId}] Error deleting document`, error)
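
After this change the route only validates, delegates to the documents service, and shapes the response. The service implementation is not part of this diff; based on the soft-delete block removed above, a plausible sketch of the extracted deleteDocument, assuming the same drizzle schema (the logger name is illustrative):

import { eq } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { document } from '@/db/schema'

const logger = createLogger('DocumentsService')

// Sketch only: the real implementation lives in @/lib/knowledge/documents/service
// and may differ; the soft-delete shape comes from the code removed above.
export async function deleteDocument(documentId: string, requestId: string) {
  // Soft delete: stamp deletedAt instead of removing the row.
  await db.update(document).set({ deletedAt: new Date() }).where(eq(document.id, documentId))
  logger.info(`[${requestId}] Document soft-deleted: ${documentId}`)
  return { success: true, message: 'Document deleted successfully' }
}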

View File

@@ -1,17 +1,17 @@
import { randomUUID } from 'crypto'
import { and, eq, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
import {
getMaxSlotsForFieldType,
getSlotsForFieldType,
SUPPORTED_FIELD_TYPES,
} from '@/lib/constants/knowledge'
import {
cleanupUnusedTagDefinitions,
createOrUpdateTagDefinitionsBulk,
deleteAllTagDefinitions,
getDocumentTagDefinitions,
} from '@/lib/knowledge/tags/service'
import type { BulkTagDefinitionsData } from '@/lib/knowledge/tags/types'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
export const dynamic = 'force-dynamic'
@@ -29,106 +29,6 @@ const BulkTagDefinitionsSchema = z.object({
definitions: z.array(TagDefinitionSchema),
})
// Helper function to get the next available slot for a knowledge base and field type
async function getNextAvailableSlot(
knowledgeBaseId: string,
fieldType: string,
existingBySlot?: Map<string, any>
): Promise<string | null> {
// Get available slots for this field type
const availableSlots = getSlotsForFieldType(fieldType)
let usedSlots: Set<string>
if (existingBySlot) {
// Use provided map if available (for performance in batch operations)
// Filter by field type
usedSlots = new Set(
Array.from(existingBySlot.entries())
.filter(([_, def]) => def.fieldType === fieldType)
.map(([slot, _]) => slot)
)
} else {
// Query database for existing tag definitions of the same field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
}
// Find the first available slot for this field type
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
return slot
}
}
return null // No available slots for this field type
}
// Helper function to clean up unused tag definitions
async function cleanupUnusedTagDefinitions(knowledgeBaseId: string, requestId: string) {
try {
logger.info(`[${requestId}] Starting cleanup for KB ${knowledgeBaseId}`)
// Get all tag definitions for this KB
const allDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
logger.info(`[${requestId}] Found ${allDefinitions.length} tag definitions to check`)
if (allDefinitions.length === 0) {
return 0
}
let cleanedCount = 0
// For each tag definition, check if any documents use that tag slot
for (const definition of allDefinitions) {
const slot = definition.tagSlot
// Use raw SQL with proper column name injection
const countResult = await db.execute(sql`
SELECT count(*) as count
FROM document
WHERE knowledge_base_id = ${knowledgeBaseId}
AND ${sql.raw(slot)} IS NOT NULL
AND trim(${sql.raw(slot)}) != ''
`)
const count = Number(countResult[0]?.count) || 0
logger.info(
`[${requestId}] Tag ${definition.displayName} (${slot}): ${count} documents using it`
)
// If count is 0, remove this tag definition
if (count === 0) {
await db
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.id, definition.id))
cleanedCount++
logger.info(
`[${requestId}] Removed unused tag definition: ${definition.displayName} (${definition.tagSlot})`
)
}
}
return cleanedCount
} catch (error) {
logger.warn(`[${requestId}] Failed to cleanup unused tag definitions:`, error)
return 0 // Don't fail the main operation if cleanup fails
}
}
// GET /api/knowledge/[id]/documents/[documentId]/tag-definitions - Get tag definitions for a document
export async function GET(
req: NextRequest,
@@ -145,35 +45,22 @@ export async function GET(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Verify document exists and belongs to the knowledge base
const documentExists = await db
.select({ id: document.id })
.from(document)
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
.limit(1)
if (documentExists.length === 0) {
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
}
const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Get tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
fieldType: knowledgeBaseTagDefinitions.fieldType,
createdAt: knowledgeBaseTagDefinitions.createdAt,
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const tagDefinitions = await getDocumentTagDefinitions(knowledgeBaseId)
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
@@ -203,21 +90,19 @@ export async function POST(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has write access to the knowledge base
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
// Verify document exists and user has write access
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Verify document exists and belongs to the knowledge base
const documentExists = await db
.select({ id: document.id })
.from(document)
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
.limit(1)
if (documentExists.length === 0) {
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
}
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
let body
@@ -238,197 +123,24 @@ export async function POST(
const validatedData = BulkTagDefinitionsSchema.parse(body)
// Validate slots are valid for their field types
for (const definition of validatedData.definitions) {
const validSlots = getSlotsForFieldType(definition.fieldType)
if (validSlots.length === 0) {
return NextResponse.json(
{ error: `Unsupported field type: ${definition.fieldType}` },
{ status: 400 }
)
}
if (!validSlots.includes(definition.tagSlot)) {
return NextResponse.json(
{
error: `Invalid slot '${definition.tagSlot}' for field type '${definition.fieldType}'. Valid slots: ${validSlots.join(', ')}`,
},
{ status: 400 }
)
}
const bulkData: BulkTagDefinitionsData = {
definitions: validatedData.definitions.map((def) => ({
tagSlot: def.tagSlot,
displayName: def.displayName,
fieldType: def.fieldType,
originalDisplayName: def._originalDisplayName,
})),
}
// Validate no duplicate tag slots within the same field type
const slotsByFieldType = new Map<string, Set<string>>()
for (const definition of validatedData.definitions) {
if (!slotsByFieldType.has(definition.fieldType)) {
slotsByFieldType.set(definition.fieldType, new Set())
}
const slotsForType = slotsByFieldType.get(definition.fieldType)!
if (slotsForType.has(definition.tagSlot)) {
return NextResponse.json(
{
error: `Duplicate slot '${definition.tagSlot}' for field type '${definition.fieldType}'`,
},
{ status: 400 }
)
}
slotsForType.add(definition.tagSlot)
}
const now = new Date()
const createdDefinitions: (typeof knowledgeBaseTagDefinitions.$inferSelect)[] = []
// Get existing definitions
const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
// Group by field type for validation
const existingByFieldType = new Map<string, number>()
for (const def of existingDefinitions) {
existingByFieldType.set(def.fieldType, (existingByFieldType.get(def.fieldType) || 0) + 1)
}
// Validate we don't exceed limits per field type
const newByFieldType = new Map<string, number>()
for (const definition of validatedData.definitions) {
// Skip validation for edit operations - they don't create new slots
if (definition._originalDisplayName) {
continue
}
const existingTagNames = new Set(
existingDefinitions
.filter((def) => def.fieldType === definition.fieldType)
.map((def) => def.displayName)
)
if (!existingTagNames.has(definition.displayName)) {
newByFieldType.set(
definition.fieldType,
(newByFieldType.get(definition.fieldType) || 0) + 1
)
}
}
for (const [fieldType, newCount] of newByFieldType.entries()) {
const existingCount = existingByFieldType.get(fieldType) || 0
const maxSlots = getMaxSlotsForFieldType(fieldType)
if (existingCount + newCount > maxSlots) {
return NextResponse.json(
{
error: `Cannot create ${newCount} new '${fieldType}' tags. Knowledge base already has ${existingCount} '${fieldType}' tag definitions. Maximum is ${maxSlots} per field type.`,
},
{ status: 400 }
)
}
}
// Use transaction to ensure consistency
await db.transaction(async (tx) => {
// Create maps for lookups
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
// Process each definition
for (const definition of validatedData.definitions) {
if (definition._originalDisplayName) {
// This is an EDIT operation - find by original name and update
const originalDefinition = existingByName.get(definition._originalDisplayName)
if (originalDefinition) {
logger.info(
`[${requestId}] Editing tag definition: ${definition._originalDisplayName} -> ${definition.displayName} (slot ${originalDefinition.tagSlot})`
)
await tx
.update(knowledgeBaseTagDefinitions)
.set({
displayName: definition.displayName,
fieldType: definition.fieldType,
updatedAt: now,
})
.where(eq(knowledgeBaseTagDefinitions.id, originalDefinition.id))
createdDefinitions.push({
...originalDefinition,
displayName: definition.displayName,
fieldType: definition.fieldType,
updatedAt: now,
})
continue
}
logger.warn(
`[${requestId}] Could not find original definition for: ${definition._originalDisplayName}`
)
}
// Regular create/update logic
const existingByDisplayName = existingByName.get(definition.displayName)
if (existingByDisplayName) {
// Display name exists - UPDATE operation
logger.info(
`[${requestId}] Updating existing tag definition: ${definition.displayName} (slot ${existingByDisplayName.tagSlot})`
)
await tx
.update(knowledgeBaseTagDefinitions)
.set({
fieldType: definition.fieldType,
updatedAt: now,
})
.where(eq(knowledgeBaseTagDefinitions.id, existingByDisplayName.id))
createdDefinitions.push({
...existingByDisplayName,
fieldType: definition.fieldType,
updatedAt: now,
})
} else {
// Display name doesn't exist - CREATE operation
const targetSlot = await getNextAvailableSlot(
knowledgeBaseId,
definition.fieldType,
existingBySlot
)
if (!targetSlot) {
logger.error(
`[${requestId}] No available slots for new tag definition: ${definition.displayName}`
)
continue
}
logger.info(
`[${requestId}] Creating new tag definition: ${definition.displayName} -> ${targetSlot}`
)
const newDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot: targetSlot as any,
displayName: definition.displayName,
fieldType: definition.fieldType,
createdAt: now,
updatedAt: now,
}
await tx.insert(knowledgeBaseTagDefinitions).values(newDefinition)
existingBySlot.set(targetSlot as any, newDefinition)
createdDefinitions.push(newDefinition as any)
}
}
})
logger.info(`[${requestId}] Created/updated ${createdDefinitions.length} tag definitions`)
const result = await createOrUpdateTagDefinitionsBulk(knowledgeBaseId, bulkData, requestId)
return NextResponse.json({
success: true,
data: createdDefinitions,
data: {
created: result.created,
updated: result.updated,
errors: result.errors,
},
})
} catch (error) {
if (error instanceof z.ZodError) {
@@ -459,10 +171,19 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has write access to the knowledge base
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
// Verify document exists and user has write access
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
if (accessCheck.notFound) {
logger.warn(
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
)
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
}
logger.warn(
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
)
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
if (action === 'cleanup') {
@@ -478,13 +199,12 @@ export async function DELETE(
// Delete all tag definitions (original behavior)
logger.info(`[${requestId}] Deleting all tag definitions for KB ${knowledgeBaseId}`)
const result = await db
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const deletedCount = await deleteAllTagDefinitions(knowledgeBaseId, requestId)
return NextResponse.json({
success: true,
message: 'Tag definitions deleted successfully',
data: { deleted: deletedCount },
})
} catch (error) {
logger.error(`[${requestId}] Error with tag definitions operation`, error)
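
The slot-allocation helper deleted above moved into the tags service. Its core is a first-fit scan over the slots valid for a field type; a simplified, pure sketch of that scan (per the removed code, the service version also falls back to querying the tag-definition table when no map of existing definitions is passed in):

// First-fit slot scan, distilled from the removed getNextAvailableSlot helper.
function nextFreeSlot(availableSlots: string[], usedSlots: Set<string>): string | null {
  for (const slot of availableSlots) {
    if (!usedSlots.has(slot)) return slot
  }
  return null // no free slot left for this field type
}

// e.g. with text slots tag1..tag3 and tag1 taken:
// nextFreeSlot(['tag1', 'tag2', 'tag3'], new Set(['tag1'])) === 'tag2'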

View File

@@ -24,6 +24,19 @@ vi.mock('@/app/api/knowledge/utils', () => ({
processDocumentAsync: vi.fn(),
}))
vi.mock('@/lib/knowledge/documents/service', () => ({
getDocuments: vi.fn(),
createSingleDocument: vi.fn(),
createDocumentRecords: vi.fn(),
processDocumentsWithQueue: vi.fn(),
getProcessingConfig: vi.fn(),
bulkDocumentOperation: vi.fn(),
updateDocument: vi.fn(),
deleteDocument: vi.fn(),
markDocumentAsFailedTimeout: vi.fn(),
retryDocumentProcessing: vi.fn(),
}))
mockDrizzleOrm()
mockConsoleLogger()
@@ -72,7 +85,6 @@ describe('Knowledge Base Documents API Route', () => {
}
}
})
// Clear all mocks - they will be set up in individual tests
}
beforeEach(async () => {
@@ -96,6 +108,7 @@ describe('Knowledge Base Documents API Route', () => {
it('should retrieve documents successfully for authenticated user', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -103,11 +116,15 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -118,12 +135,22 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.success).toBe(true)
expect(data.data.documents).toHaveLength(1)
expect(data.data.documents[0].id).toBe('doc-123')
expect(mockDbChain.select).toHaveBeenCalled()
expect(vi.mocked(checkKnowledgeBaseAccess)).toHaveBeenCalledWith('kb-123', 'user-123')
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: false,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should filter disabled documents by default', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -131,22 +158,36 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
const response = await GET(req, { params: mockParams })
expect(response.status).toBe(200)
expect(mockDbChain.where).toHaveBeenCalled()
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: false,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should include disabled documents when requested', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
@@ -154,11 +195,15 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock the count query (first query)
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
// Mock the documents query (second query)
mockDbChain.offset.mockResolvedValue([mockDocument])
vi.mocked(getDocuments).mockResolvedValue({
documents: [mockDocument],
pagination: {
total: 1,
limit: 50,
offset: 0,
hasMore: false,
},
})
const url = 'http://localhost:3000/api/knowledge/kb-123/documents?includeDisabled=true'
const req = new Request(url, { method: 'GET' }) as any
@@ -167,6 +212,16 @@ describe('Knowledge Base Documents API Route', () => {
const response = await GET(req, { params: mockParams })
expect(response.status).toBe(200)
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
'kb-123',
{
includeDisabled: true,
search: undefined,
limit: 50,
offset: 0,
},
expect.any(String)
)
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -216,13 +271,14 @@ describe('Knowledge Base Documents API Route', () => {
it('should handle database errors', async () => {
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
const { getDocuments } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.orderBy.mockRejectedValue(new Error('Database error'))
vi.mocked(getDocuments).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -245,13 +301,35 @@ describe('Knowledge Base Documents API Route', () => {
it('should create single document successfully', async () => {
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.values.mockResolvedValue(undefined)
const createdDocument = {
id: 'doc-123',
knowledgeBaseId: 'kb-123',
filename: validDocumentData.filename,
fileUrl: validDocumentData.fileUrl,
fileSize: validDocumentData.fileSize,
mimeType: validDocumentData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
enabled: true,
uploadedAt: new Date(),
tag1: null,
tag2: null,
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
}
vi.mocked(createSingleDocument).mockResolvedValue(createdDocument)
const req = createMockRequest('POST', validDocumentData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -262,7 +340,11 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.success).toBe(true)
expect(data.data.filename).toBe(validDocumentData.filename)
expect(data.data.fileUrl).toBe(validDocumentData.fileUrl)
expect(mockDbChain.insert).toHaveBeenCalled()
expect(vi.mocked(createSingleDocument)).toHaveBeenCalledWith(
validDocumentData,
'kb-123',
expect.any(String)
)
})
it('should validate single document data', async () => {
@@ -320,9 +402,9 @@ describe('Knowledge Base Documents API Route', () => {
}
it('should create bulk documents successfully', async () => {
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
@@ -330,17 +412,31 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to return the created documents
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
insert: vi.fn().mockReturnValue({
values: vi.fn().mockResolvedValue(undefined),
}),
}
return await callback(mockTx)
})
const createdDocuments = [
{
documentId: 'doc-1',
filename: 'doc1.pdf',
fileUrl: 'https://example.com/doc1.pdf',
fileSize: 1024,
mimeType: 'application/pdf',
},
{
documentId: 'doc-2',
filename: 'doc2.pdf',
fileUrl: 'https://example.com/doc2.pdf',
fileSize: 2048,
mimeType: 'application/pdf',
},
]
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
vi.mocked(getProcessingConfig).mockReturnValue({
maxConcurrentDocuments: 8,
batchSize: 20,
delayBetweenBatches: 100,
delayBetweenDocuments: 0,
})
const req = createMockRequest('POST', validBulkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
@@ -352,7 +448,12 @@ describe('Knowledge Base Documents API Route', () => {
expect(data.data.total).toBe(2)
expect(data.data.documentsCreated).toHaveLength(2)
expect(data.data.processingMethod).toBe('background')
expect(mockDbChain.transaction).toHaveBeenCalled()
expect(vi.mocked(createDocumentRecords)).toHaveBeenCalledWith(
validBulkData.documents,
'kb-123',
expect.any(String)
)
expect(vi.mocked(processDocumentsWithQueue)).toHaveBeenCalled()
})
it('should validate bulk document data', async () => {
@@ -394,9 +495,9 @@ describe('Knowledge Base Documents API Route', () => {
})
it('should handle processing errors gracefully', async () => {
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
'@/app/api/knowledge/utils'
)
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
@@ -404,26 +505,30 @@ describe('Knowledge Base Documents API Route', () => {
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
// Mock transaction to succeed but processing to fail
mockDbChain.transaction.mockImplementation(async (callback) => {
const mockTx = {
insert: vi.fn().mockReturnValue({
values: vi.fn().mockResolvedValue(undefined),
}),
}
return await callback(mockTx)
})
const createdDocuments = [
{
documentId: 'doc-1',
filename: 'doc1.pdf',
fileUrl: 'https://example.com/doc1.pdf',
fileSize: 1024,
mimeType: 'application/pdf',
},
]
// Don't reject the promise - the processing is async and catches errors internally
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
vi.mocked(getProcessingConfig).mockReturnValue({
maxConcurrentDocuments: 8,
batchSize: 20,
delayBetweenBatches: 100,
delayBetweenDocuments: 0,
})
const req = createMockRequest('POST', validBulkData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
const response = await POST(req, { params: mockParams })
const data = await response.json()
// The endpoint should still return success since documents are created
// and processing happens asynchronously
expect(response.status).toBe(200)
expect(data.success).toBe(true)
})
@@ -485,13 +590,14 @@ describe('Knowledge Base Documents API Route', () => {
it('should handle database errors during creation', async () => {
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
mockAuth$.mockAuthenticatedUser()
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.values.mockRejectedValue(new Error('Database error'))
vi.mocked(createSingleDocument).mockRejectedValue(new Error('Database error'))
const req = createMockRequest('POST', validDocumentData)
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
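
These tests stub getProcessingConfig with the queue settings shown above, and the route derives totalBatches from batchSize. A small sketch of that shape, with the interface inferred from the mocked return value (the real type lives in the service and may carry more fields):

interface ProcessingConfig {
  maxConcurrentDocuments: number
  batchSize: number
  delayBetweenBatches: number // ms between batches
  delayBetweenDocuments: number // ms between documents in a batch
}

const config: ProcessingConfig = {
  maxConcurrentDocuments: 8,
  batchSize: 20,
  delayBetweenBatches: 100,
  delayBetweenDocuments: 0,
}

// Mirrors the route's processingConfig response: 2 created documents => 1 batch.
const totalBatches = Math.ceil(2 / config.batchSize) // 1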

View File

@@ -1,279 +1,22 @@
import { randomUUID } from 'crypto'
import { and, desc, eq, inArray, isNull, sql } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { getSlotsForFieldType } from '@/lib/constants/knowledge'
import {
bulkDocumentOperation,
createDocumentRecords,
createSingleDocument,
getDocuments,
getProcessingConfig,
processDocumentsWithQueue,
} from '@/lib/knowledge/documents/service'
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserId } from '@/app/api/auth/oauth/utils'
import {
checkKnowledgeBaseAccess,
checkKnowledgeBaseWriteAccess,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
const logger = createLogger('DocumentsAPI')
const PROCESSING_CONFIG = {
maxConcurrentDocuments: 3,
batchSize: 5,
delayBetweenBatches: 1000,
delayBetweenDocuments: 500,
}
// Helper function to get the next available slot for a knowledge base and field type
async function getNextAvailableSlot(
knowledgeBaseId: string,
fieldType: string,
existingBySlot?: Map<string, any>
): Promise<string | null> {
let usedSlots: Set<string>
if (existingBySlot) {
// Use provided map if available (for performance in batch operations)
// Filter by field type
usedSlots = new Set(
Array.from(existingBySlot.entries())
.filter(([_, def]) => def.fieldType === fieldType)
.map(([slot, _]) => slot)
)
} else {
// Query database for existing tag definitions of the same field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
}
// Find the first available slot for this field type
const availableSlots = getSlotsForFieldType(fieldType)
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
return slot
}
}
return null // No available slots for this field type
}
// Helper function to process structured document tags
async function processDocumentTags(
knowledgeBaseId: string,
tagData: Array<{ tagName: string; fieldType: string; value: string }>,
requestId: string
): Promise<Record<string, string | null>> {
const result: Record<string, string | null> = {}
// Initialize all text tag slots to null (only text type is supported currently)
const textSlots = getSlotsForFieldType('text')
textSlots.forEach((slot) => {
result[slot] = null
})
if (!Array.isArray(tagData) || tagData.length === 0) {
return result
}
try {
// Get existing tag definitions
const existingDefinitions = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
// Process each tag
for (const tag of tagData) {
if (!tag.tagName?.trim() || !tag.value?.trim()) continue
const tagName = tag.tagName.trim()
const fieldType = tag.fieldType
const value = tag.value.trim()
let targetSlot: string | null = null
// Check if tag definition already exists
const existingDef = existingByName.get(tagName)
if (existingDef) {
targetSlot = existingDef.tagSlot
} else {
// Find next available slot using the helper function
targetSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
// Create new tag definition if we have a slot
if (targetSlot) {
const newDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot: targetSlot as any,
displayName: tagName,
fieldType,
createdAt: new Date(),
updatedAt: new Date(),
}
await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
existingBySlot.set(targetSlot as any, newDefinition)
logger.info(`[${requestId}] Created tag definition: ${tagName} -> ${targetSlot}`)
}
}
// Assign value to the slot
if (targetSlot) {
result[targetSlot] = value
}
}
return result
} catch (error) {
logger.error(`[${requestId}] Error processing document tags:`, error)
return result
}
}
async function processDocumentsWithConcurrencyControl(
createdDocuments: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const totalDocuments = createdDocuments.length
const batches = []
for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) {
batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize))
}
logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`)
for (const [batchIndex, batch] of batches.entries()) {
logger.info(
`[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents`
)
await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId)
if (batchIndex < batches.length - 1) {
await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches))
}
}
logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`)
}
async function processBatchWithConcurrency(
batch: Array<{
documentId: string
filename: string
fileUrl: string
fileSize: number
mimeType: string
}>,
knowledgeBaseId: string,
processingOptions: {
chunkSize: number
minCharactersPerChunk: number
recipe: string
lang: string
chunkOverlap: number
},
requestId: string
): Promise<void> {
const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0)
const processingPromises = batch.map(async (doc, index) => {
if (index > 0) {
await new Promise((resolve) =>
setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments)
)
}
await new Promise<void>((resolve) => {
const checkSlot = () => {
const availableIndex = semaphore.findIndex((slot) => slot === 0)
if (availableIndex !== -1) {
semaphore[availableIndex] = 1
resolve()
} else {
setTimeout(checkSlot, 100)
}
}
checkSlot()
})
try {
logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`)
await processDocumentAsync(
knowledgeBaseId,
doc.documentId,
{
filename: doc.filename,
fileUrl: doc.fileUrl,
fileSize: doc.fileSize,
mimeType: doc.mimeType,
},
processingOptions
)
logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`)
} catch (error: unknown) {
logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, {
documentId: doc.documentId,
filename: doc.filename,
error: error instanceof Error ? error.message : 'Unknown error',
})
try {
await db
.update(document)
.set({
processingStatus: 'failed',
processingError:
error instanceof Error ? error.message : 'Failed to initiate processing',
processingCompletedAt: new Date(),
})
.where(eq(document.id, doc.documentId))
} catch (dbError: unknown) {
logger.error(
`[${requestId}] Failed to update document status for failed document: ${doc.documentId}`,
dbError
)
}
} finally {
const slotIndex = semaphore.findIndex((slot) => slot === 1)
if (slotIndex !== -1) {
semaphore[slotIndex] = 0
}
}
})
await Promise.allSettled(processingPromises)
}
const CreateDocumentSchema = z.object({
filename: z.string().min(1, 'Filename is required'),
fileUrl: z.string().url('File URL must be valid'),
@@ -337,83 +80,50 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
const url = new URL(req.url)
const includeDisabled = url.searchParams.get('includeDisabled') === 'true'
const search = url.searchParams.get('search')
const search = url.searchParams.get('search') || undefined
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
const offset = Number.parseInt(url.searchParams.get('offset') || '0')
const sortByParam = url.searchParams.get('sortBy')
const sortOrderParam = url.searchParams.get('sortOrder')
// Build where conditions
const whereConditions = [
eq(document.knowledgeBaseId, knowledgeBaseId),
isNull(document.deletedAt),
// Validate sort parameters
const validSortFields: DocumentSortField[] = [
'filename',
'fileSize',
'tokenCount',
'chunkCount',
'uploadedAt',
'processingStatus',
]
const validSortOrders: SortOrder[] = ['asc', 'desc']
// Filter out disabled documents unless specifically requested
if (!includeDisabled) {
whereConditions.push(eq(document.enabled, true))
}
const sortBy =
sortByParam && validSortFields.includes(sortByParam as DocumentSortField)
? (sortByParam as DocumentSortField)
: undefined
const sortOrder =
sortOrderParam && validSortOrders.includes(sortOrderParam as SortOrder)
? (sortOrderParam as SortOrder)
: undefined
// Add search condition if provided
if (search) {
whereConditions.push(
// Search in filename
sql`LOWER(${document.filename}) LIKE LOWER(${`%${search}%`})`
)
}
// Get total count for pagination
const totalResult = await db
.select({ count: sql<number>`COUNT(*)` })
.from(document)
.where(and(...whereConditions))
const total = totalResult[0]?.count || 0
const hasMore = offset + limit < total
const documents = await db
.select({
id: document.id,
filename: document.filename,
fileUrl: document.fileUrl,
fileSize: document.fileSize,
mimeType: document.mimeType,
chunkCount: document.chunkCount,
tokenCount: document.tokenCount,
characterCount: document.characterCount,
processingStatus: document.processingStatus,
processingStartedAt: document.processingStartedAt,
processingCompletedAt: document.processingCompletedAt,
processingError: document.processingError,
enabled: document.enabled,
uploadedAt: document.uploadedAt,
// Include tags in response
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
})
.from(document)
.where(and(...whereConditions))
.orderBy(desc(document.uploadedAt))
.limit(limit)
.offset(offset)
logger.info(
`[${requestId}] Retrieved ${documents.length} documents (${offset}-${offset + documents.length} of ${total}) for knowledge base ${knowledgeBaseId}`
)
const result = await getDocuments(
knowledgeBaseId,
{
includeDisabled,
search,
limit,
offset,
...(sortBy && { sortBy }),
...(sortOrder && { sortOrder }),
},
requestId
)
return NextResponse.json({
success: true,
data: {
documents,
pagination: {
total,
limit,
offset,
hasMore,
},
documents: result.documents,
pagination: result.pagination,
},
})
} catch (error) {
@@ -462,80 +172,21 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if this is a bulk operation
if (body.bulk === true) {
// Handle bulk processing (replaces process-documents endpoint)
try {
const validatedData = BulkCreateDocumentsSchema.parse(body)
const createdDocuments = await db.transaction(async (tx) => {
const documentPromises = validatedData.documents.map(async (docData) => {
const documentId = randomUUID()
const now = new Date()
// Process documentTagsData if provided (for knowledge base block)
let processedTags: Record<string, string | null> = {
tag1: null,
tag2: null,
tag3: null,
tag4: null,
tag5: null,
tag6: null,
tag7: null,
}
if (docData.documentTagsData) {
try {
const tagData = JSON.parse(docData.documentTagsData)
if (Array.isArray(tagData)) {
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
}
} catch (error) {
logger.warn(
`[${requestId}] Failed to parse documentTagsData for bulk document:`,
error
)
}
}
const newDocument = {
id: documentId,
knowledgeBaseId,
filename: docData.filename,
fileUrl: docData.fileUrl,
fileSize: docData.fileSize,
mimeType: docData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
processingStatus: 'pending' as const,
enabled: true,
uploadedAt: now,
// Use processed tags if available, otherwise fall back to individual tag fields
tag1: processedTags.tag1 || docData.tag1 || null,
tag2: processedTags.tag2 || docData.tag2 || null,
tag3: processedTags.tag3 || docData.tag3 || null,
tag4: processedTags.tag4 || docData.tag4 || null,
tag5: processedTags.tag5 || docData.tag5 || null,
tag6: processedTags.tag6 || docData.tag6 || null,
tag7: processedTags.tag7 || docData.tag7 || null,
}
await tx.insert(document).values(newDocument)
logger.info(
`[${requestId}] Document record created: ${documentId} for file: ${docData.filename}`
)
return { documentId, ...docData }
})
return await Promise.all(documentPromises)
})
const createdDocuments = await createDocumentRecords(
validatedData.documents,
knowledgeBaseId,
requestId
)
logger.info(
`[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents`
)
processDocumentsWithConcurrencyControl(
processDocumentsWithQueue(
createdDocuments,
knowledgeBaseId,
validatedData.processingOptions,
@@ -555,9 +206,9 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
})),
processingMethod: 'background',
processingConfig: {
maxConcurrentDocuments: PROCESSING_CONFIG.maxConcurrentDocuments,
batchSize: PROCESSING_CONFIG.batchSize,
totalBatches: Math.ceil(createdDocuments.length / PROCESSING_CONFIG.batchSize),
maxConcurrentDocuments: getProcessingConfig().maxConcurrentDocuments,
batchSize: getProcessingConfig().batchSize,
totalBatches: Math.ceil(createdDocuments.length / getProcessingConfig().batchSize),
},
},
})
@@ -578,52 +229,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
try {
const validatedData = CreateDocumentSchema.parse(body)
const documentId = randomUUID()
const now = new Date()
// Process structured tag data if provided
let processedTags: Record<string, string | null> = {
tag1: validatedData.tag1 || null,
tag2: validatedData.tag2 || null,
tag3: validatedData.tag3 || null,
tag4: validatedData.tag4 || null,
tag5: validatedData.tag5 || null,
tag6: validatedData.tag6 || null,
tag7: validatedData.tag7 || null,
}
if (validatedData.documentTagsData) {
try {
const tagData = JSON.parse(validatedData.documentTagsData)
if (Array.isArray(tagData)) {
// Process structured tag data and create tag definitions
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
}
} catch (error) {
logger.warn(`[${requestId}] Failed to parse documentTagsData:`, error)
}
}
const newDocument = {
id: documentId,
knowledgeBaseId,
filename: validatedData.filename,
fileUrl: validatedData.fileUrl,
fileSize: validatedData.fileSize,
mimeType: validatedData.mimeType,
chunkCount: 0,
tokenCount: 0,
characterCount: 0,
enabled: true,
uploadedAt: now,
...processedTags,
}
await db.insert(document).values(newDocument)
logger.info(
`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`
)
const newDocument = await createSingleDocument(validatedData, knowledgeBaseId, requestId)
return NextResponse.json({
success: true,
@@ -649,7 +255,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
}
export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id: string }> }) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
const { id: knowledgeBaseId } = await params
try {
@@ -678,89 +284,28 @@ export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id
const validatedData = BulkUpdateDocumentsSchema.parse(body)
const { operation, documentIds } = validatedData
logger.info(
`[${requestId}] Starting bulk ${operation} operation on ${documentIds.length} documents in knowledge base ${knowledgeBaseId}`
)
// Verify all documents belong to this knowledge base and user has access
const documentsToUpdate = await db
.select({
id: document.id,
enabled: document.enabled,
})
.from(document)
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
if (documentsToUpdate.length === 0) {
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
}
if (documentsToUpdate.length !== documentIds.length) {
logger.warn(
`[${requestId}] Some documents not found or don't belong to knowledge base. Requested: ${documentIds.length}, Found: ${documentsToUpdate.length}`
)
}
// Perform the bulk operation
let updateResult: Array<{ id: string; enabled?: boolean; deletedAt?: Date | null }>
let successCount: number
if (operation === 'delete') {
// Handle bulk soft delete
updateResult = await db
.update(document)
.set({
deletedAt: new Date(),
})
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
.returning({ id: document.id, deletedAt: document.deletedAt })
successCount = updateResult.length
} else {
// Handle bulk enable/disable
const enabled = operation === 'enable'
updateResult = await db
.update(document)
.set({
enabled,
})
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
inArray(document.id, documentIds),
isNull(document.deletedAt)
)
)
.returning({ id: document.id, enabled: document.enabled })
successCount = updateResult.length
}
logger.info(
`[${requestId}] Bulk ${operation} operation completed: ${successCount} documents updated in knowledge base ${knowledgeBaseId}`
)
return NextResponse.json({
success: true,
data: {
operation,
successCount,
updatedDocuments: updateResult,
},
})
try {
const result = await bulkDocumentOperation(
knowledgeBaseId,
operation,
documentIds,
requestId
)
return NextResponse.json({
success: true,
data: {
operation,
successCount: result.successCount,
updatedDocuments: result.updatedDocuments,
},
})
} catch (error) {
if (error instanceof Error && error.message === 'No valid documents found to update') {
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
}
throw error
}
} catch (validationError) {
if (validationError instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid bulk operation data`, {
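
The deleted processBatchWithConcurrency capped in-flight documents by polling a semaphore array every 100 ms. Its replacement, processDocumentsWithQueue, is not shown in this diff; as a point of comparison only, the same cap can be expressed without polling as a shared-cursor worker pool:

// Worker-pool sketch of the concurrency cap the removed semaphore loop enforced.
async function runWithConcurrency<T>(
  items: T[],
  limit: number,
  task: (item: T) => Promise<void>
): Promise<void> {
  let next = 0 // shared cursor; safe because JS callbacks run single-threaded
  const workers = Array.from({ length: Math.min(limit, items.length) }, async () => {
    while (next < items.length) {
      const item = items[next++]
      // Per-item failures are swallowed here, matching the removed code's
      // log-and-continue behavior (it marked the document failed instead).
      await task(item).catch(() => {})
    }
  })
  await Promise.all(workers)
}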

View File

@@ -1,12 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getMaxSlotsForFieldType, getSlotsForFieldType } from '@/lib/constants/knowledge'
import { getNextAvailableSlot, getTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
const logger = createLogger('NextAvailableSlotAPI')
@@ -31,51 +28,36 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has read access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get available slots for this field type
const availableSlots = getSlotsForFieldType(fieldType)
const maxSlots = getMaxSlotsForFieldType(fieldType)
// Get existing definitions once and reuse
const existingDefinitions = await getTagDefinitions(knowledgeBaseId)
const usedSlots = existingDefinitions
.filter((def) => def.fieldType === fieldType)
.map((def) => def.tagSlot)
// Get existing tag definitions to find used slots for this field type
const existingDefinitions = await db
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
)
)
const usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot as string))
// Find the first available slot for this field type
let nextAvailableSlot: string | null = null
for (const slot of availableSlots) {
if (!usedSlots.has(slot)) {
nextAvailableSlot = slot
break
}
}
// Create a map for efficient lookup and pass to avoid redundant query
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot as string, def]))
const nextAvailableSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
logger.info(
`[${requestId}] Next available slot for fieldType ${fieldType}: ${nextAvailableSlot}`
)
const result = {
nextAvailableSlot,
fieldType,
usedSlots,
totalSlots: 7,
availableSlots: nextAvailableSlot ? 7 - usedSlots.length : 0,
}
return NextResponse.json({
success: true,
data: {
nextAvailableSlot,
fieldType,
usedSlots: Array.from(usedSlots),
totalSlots: maxSlots,
availableSlots: maxSlots - usedSlots.size,
},
data: result,
})
} catch (error) {
logger.error(`[${requestId}] Error getting next available slot`, error)

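The slot computation now lives in `getNextAvailableSlot`, which is not shown here. A sketch of what it plausibly does, given that the route prefetches definitions into `existingBySlot` to avoid a redundant query — the signature comes from the call site, the body is assumed:

import { getSlotsForFieldType } from '@/lib/constants/knowledge'

// Hypothetical sketch; the real service may also fetch definitions itself
// when no prefetched map is supplied (knowledgeBaseId is unused here).
export async function getNextAvailableSlot(
  knowledgeBaseId: string,
  fieldType: string,
  existingBySlot: Map<string, unknown>
): Promise<string | null> {
  // Scan the fixed slot list for this field type and return the first
  // slot that has no existing tag definition.
  for (const slot of getSlotsForFieldType(fieldType)) {
    if (!existingBySlot.has(slot)) {
      return slot
    }
  }
  return null
}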
View File

@@ -16,9 +16,26 @@ mockKnowledgeSchemas()
mockDrizzleOrm()
mockConsoleLogger()
vi.mock('@/lib/knowledge/service', () => ({
getKnowledgeBaseById: vi.fn(),
updateKnowledgeBase: vi.fn(),
deleteKnowledgeBase: vi.fn(),
}))
vi.mock('@/app/api/knowledge/utils', () => ({
checkKnowledgeBaseAccess: vi.fn(),
checkKnowledgeBaseWriteAccess: vi.fn(),
}))
describe('Knowledge Base By ID API Route', () => {
const mockAuth$ = mockAuth()
let mockGetKnowledgeBaseById: any
let mockUpdateKnowledgeBase: any
let mockDeleteKnowledgeBase: any
let mockCheckKnowledgeBaseAccess: any
let mockCheckKnowledgeBaseWriteAccess: any
const mockDbChain = {
select: vi.fn().mockReturnThis(),
from: vi.fn().mockReturnThis(),
@@ -62,6 +79,15 @@ describe('Knowledge Base By ID API Route', () => {
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
})
const knowledgeService = await import('@/lib/knowledge/service')
const knowledgeUtils = await import('@/app/api/knowledge/utils')
mockGetKnowledgeBaseById = knowledgeService.getKnowledgeBaseById as any
mockUpdateKnowledgeBase = knowledgeService.updateKnowledgeBase as any
mockDeleteKnowledgeBase = knowledgeService.deleteKnowledgeBase as any
mockCheckKnowledgeBaseAccess = knowledgeUtils.checkKnowledgeBaseAccess as any
mockCheckKnowledgeBaseWriteAccess = knowledgeUtils.checkKnowledgeBaseWriteAccess as any
})
afterEach(() => {
@@ -74,9 +100,12 @@ describe('Knowledge Base By ID API Route', () => {
it('should retrieve knowledge base successfully for authenticated user', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.limit.mockResolvedValueOnce([mockKnowledgeBase])
mockGetKnowledgeBaseById.mockResolvedValueOnce(mockKnowledgeBase)
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -87,7 +116,8 @@ describe('Knowledge Base By ID API Route', () => {
expect(data.success).toBe(true)
expect(data.data.id).toBe('kb-123')
expect(data.data.name).toBe('Test Knowledge Base')
expect(mockDbChain.select).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockGetKnowledgeBaseById).toHaveBeenCalledWith('kb-123')
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -105,7 +135,10 @@ describe('Knowledge Base By ID API Route', () => {
it('should return not found for non-existent knowledge base', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -119,7 +152,10 @@ describe('Knowledge Base By ID API Route', () => {
it('should return unauthorized for knowledge base owned by different user', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: false,
})
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -130,9 +166,29 @@ describe('Knowledge Base By ID API Route', () => {
expect(data.error).toBe('Unauthorized')
})
it('should return not found when service returns null', async () => {
mockAuth$.mockAuthenticatedUser()
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockGetKnowledgeBaseById.mockResolvedValueOnce(null)
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
const response = await GET(req, { params: mockParams })
const data = await response.json()
expect(response.status).toBe(404)
expect(data.error).toBe('Knowledge base not found')
})
it('should handle database errors', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
mockCheckKnowledgeBaseAccess.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('GET')
const { GET } = await import('@/app/api/knowledge/[id]/route')
@@ -156,13 +212,13 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockResolvedValueOnce(undefined)
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ ...mockKnowledgeBase, ...validUpdateData }])
const updatedKnowledgeBase = { ...mockKnowledgeBase, ...validUpdateData }
mockUpdateKnowledgeBase.mockResolvedValueOnce(updatedKnowledgeBase)
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -172,7 +228,16 @@ describe('Knowledge Base By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.name).toBe('Updated Knowledge Base')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockUpdateKnowledgeBase).toHaveBeenCalledWith(
'kb-123',
{
name: validUpdateData.name,
description: validUpdateData.description,
chunkingConfig: undefined,
},
expect.any(String)
)
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -192,8 +257,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -209,8 +276,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
const invalidData = {
name: '',
@@ -229,9 +298,13 @@ describe('Knowledge Base By ID API Route', () => {
it('should handle database errors during update', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
// Mock successful write access check
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
mockUpdateKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('PUT', validUpdateData)
const { PUT } = await import('@/app/api/knowledge/[id]/route')
@@ -251,10 +324,12 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockResolvedValueOnce(undefined)
mockDeleteKnowledgeBase.mockResolvedValueOnce(undefined)
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -264,7 +339,8 @@ describe('Knowledge Base By ID API Route', () => {
expect(response.status).toBe(200)
expect(data.success).toBe(true)
expect(data.data.message).toBe('Knowledge base deleted successfully')
expect(mockDbChain.update).toHaveBeenCalled()
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
expect(mockDeleteKnowledgeBase).toHaveBeenCalledWith('kb-123', expect.any(String))
})
it('should return unauthorized for unauthenticated user', async () => {
@@ -284,8 +360,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: true,
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -301,8 +379,10 @@ describe('Knowledge Base By ID API Route', () => {
resetMocks()
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: false,
notFound: false,
})
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
@@ -316,9 +396,12 @@ describe('Knowledge Base By ID API Route', () => {
it('should handle database errors during delete', async () => {
mockAuth$.mockAuthenticatedUser()
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
hasAccess: true,
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
})
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
mockDeleteKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
const req = createMockRequest('DELETE')
const { DELETE } = await import('@/app/api/knowledge/[id]/route')

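These tests rely on a common vitest pattern: `vi.mock` factories are hoisted above imports, and the route module is imported dynamically only after the mock return values are configured. A self-contained illustration of the pattern — module paths are reused from the diff, the test itself is illustrative:

import { describe, expect, it, vi } from 'vitest'

// vi.mock calls are hoisted, so this factory replaces the module before
// any import of the route file resolves.
vi.mock('@/lib/knowledge/service', () => ({
  getKnowledgeBaseById: vi.fn(),
}))

describe('mock-then-import pattern', () => {
  it('configures the mock, then imports the route', async () => {
    const { getKnowledgeBaseById } = await import('@/lib/knowledge/service')
    ;(getKnowledgeBaseById as any).mockResolvedValueOnce({ id: 'kb-123' })
    // Importing the handler after the mock is configured ensures it sees
    // the mocked service rather than the real database-backed one.
    const { GET } = await import('@/app/api/knowledge/[id]/route')
    expect(GET).toBeDefined()
  })
})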
View File

@@ -1,11 +1,13 @@
import { and, eq, isNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import {
deleteKnowledgeBase,
getKnowledgeBaseById,
updateKnowledgeBase,
} from '@/lib/knowledge/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBase } from '@/db/schema'
const logger = createLogger('KnowledgeBaseByIdAPI')
@@ -48,13 +50,9 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const knowledgeBases = await db
.select()
.from(knowledgeBase)
.where(and(eq(knowledgeBase.id, id), isNull(knowledgeBase.deletedAt)))
.limit(1)
const knowledgeBaseData = await getKnowledgeBaseById(id)
if (knowledgeBases.length === 0) {
if (!knowledgeBaseData) {
return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 })
}
@@ -62,7 +60,7 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({
success: true,
data: knowledgeBases[0],
data: knowledgeBaseData,
})
} catch (error) {
logger.error(`[${requestId}] Error fetching knowledge base`, error)
@@ -99,42 +97,21 @@ export async function PUT(req: NextRequest, { params }: { params: Promise<{ id:
try {
const validatedData = UpdateKnowledgeBaseSchema.parse(body)
const updateData: any = {
updatedAt: new Date(),
}
if (validatedData.name !== undefined) updateData.name = validatedData.name
if (validatedData.description !== undefined)
updateData.description = validatedData.description
if (validatedData.workspaceId !== undefined)
updateData.workspaceId = validatedData.workspaceId
// Handle embedding model and dimension together to ensure consistency
if (
validatedData.embeddingModel !== undefined ||
validatedData.embeddingDimension !== undefined
) {
updateData.embeddingModel = 'text-embedding-3-small'
updateData.embeddingDimension = 1536
}
if (validatedData.chunkingConfig !== undefined)
updateData.chunkingConfig = validatedData.chunkingConfig
await db.update(knowledgeBase).set(updateData).where(eq(knowledgeBase.id, id))
// Fetch the updated knowledge base
const updatedKnowledgeBase = await db
.select()
.from(knowledgeBase)
.where(eq(knowledgeBase.id, id))
.limit(1)
const updatedKnowledgeBase = await updateKnowledgeBase(
id,
{
name: validatedData.name,
description: validatedData.description,
chunkingConfig: validatedData.chunkingConfig,
},
requestId
)
logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${session.user.id}`)
return NextResponse.json({
success: true,
data: updatedKnowledgeBase[0],
data: updatedKnowledgeBase,
})
} catch (validationError) {
if (validationError instanceof z.ZodError) {
@@ -178,14 +155,7 @@ export async function DELETE(_req: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Soft delete by setting deletedAt timestamp
await db
.update(knowledgeBase)
.set({
deletedAt: new Date(),
updatedAt: new Date(),
})
.where(eq(knowledgeBase.id, id))
await deleteKnowledgeBase(id, requestId)
logger.info(`[${requestId}] Knowledge base deleted: ${id} for user ${session.user.id}`)

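`updateKnowledgeBase` and `deleteKnowledgeBase` are imported from `@/lib/knowledge/service` but not shown in this diff. Hypothetical sketches that mirror the inline queries removed above — the real service presumably logs with `requestId` and may differ in detail:

import { eq } from 'drizzle-orm'
import { db } from '@/db'
import { knowledgeBase } from '@/db/schema'

// Sketch only: update the mutable fields, then return the fresh row.
export async function updateKnowledgeBase(
  id: string,
  updates: { name?: string; description?: string; chunkingConfig?: Record<string, unknown> },
  requestId: string
) {
  await db
    .update(knowledgeBase)
    .set({ ...updates, updatedAt: new Date() })
    .where(eq(knowledgeBase.id, id))
  const [updated] = await db
    .select()
    .from(knowledgeBase)
    .where(eq(knowledgeBase.id, id))
    .limit(1)
  return updated
}

// Sketch only: soft delete, matching the removed inline implementation.
export async function deleteKnowledgeBase(id: string, requestId: string) {
  await db
    .update(knowledgeBase)
    .set({ deletedAt: new Date(), updatedAt: new Date() })
    .where(eq(knowledgeBase.id, id))
}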
View File

@@ -1,11 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq, isNotNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { deleteTagDefinition } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -29,87 +27,16 @@ export async function DELETE(
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get the tag definition to find which slot it uses
const tagDefinition = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.id, tagId),
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)
)
)
.limit(1)
if (tagDefinition.length === 0) {
return NextResponse.json({ error: 'Tag definition not found' }, { status: 404 })
}
const tagDef = tagDefinition[0]
// Delete the tag definition and clear all document tags in a transaction
await db.transaction(async (tx) => {
logger.info(`[${requestId}] Starting transaction to delete ${tagDef.tagSlot}`)
try {
// Clear the tag from documents that actually have this tag set
logger.info(`[${requestId}] Clearing tag from documents...`)
await tx
.update(document)
.set({ [tagDef.tagSlot]: null })
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
isNotNull(document[tagDef.tagSlot as keyof typeof document.$inferSelect])
)
)
logger.info(`[${requestId}] Documents updated successfully`)
// Clear the tag from embeddings that actually have this tag set
logger.info(`[${requestId}] Clearing tag from embeddings...`)
await tx
.update(embedding)
.set({ [tagDef.tagSlot]: null })
.where(
and(
eq(embedding.knowledgeBaseId, knowledgeBaseId),
isNotNull(embedding[tagDef.tagSlot as keyof typeof embedding.$inferSelect])
)
)
logger.info(`[${requestId}] Embeddings updated successfully`)
// Delete the tag definition
logger.info(`[${requestId}] Deleting tag definition...`)
await tx
.delete(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.id, tagId))
logger.info(`[${requestId}] Tag definition deleted successfully`)
} catch (error) {
logger.error(`[${requestId}] Error in transaction:`, error)
throw error
}
})
logger.info(
`[${requestId}] Successfully deleted tag definition ${tagDef.displayName} (${tagDef.tagSlot})`
)
const deletedTag = await deleteTagDefinition(tagId, requestId)
return NextResponse.json({
success: true,
message: `Tag definition "${tagDef.displayName}" deleted successfully`,
message: `Tag definition "${deletedTag.displayName}" deleted successfully`,
})
} catch (error) {
logger.error(`[${requestId}] Error deleting tag definition`, error)

View File

@@ -1,11 +1,11 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
import { createTagDefinition, getTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -24,25 +24,12 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
fieldType: knowledgeBaseTagDefinitions.fieldType,
createdAt: knowledgeBaseTagDefinitions.createdAt,
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
const tagDefinitions = await getTagDefinitions(knowledgeBaseId)
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
@@ -69,68 +56,43 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
const body = await req.json()
const { tagSlot, displayName, fieldType } = body
if (!tagSlot || !displayName || !fieldType) {
return NextResponse.json(
{ error: 'tagSlot, displayName, and fieldType are required' },
{ status: 400 }
)
}
const CreateTagDefinitionSchema = z.object({
tagSlot: z.string().min(1, 'Tag slot is required'),
displayName: z.string().min(1, 'Display name is required'),
fieldType: z.enum(SUPPORTED_FIELD_TYPES as [string, ...string[]], {
errorMap: () => ({ message: 'Invalid field type' }),
}),
})
// Check if tag slot is already used
const existingTag = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.tagSlot, tagSlot)
)
)
.limit(1)
if (existingTag.length > 0) {
return NextResponse.json({ error: 'Tag slot is already in use' }, { status: 409 })
}
let validatedData
try {
validatedData = CreateTagDefinitionSchema.parse(body)
} catch (error) {
if (error instanceof z.ZodError) {
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
throw error
}
// Check if display name is already used
const existingName = await db
.select()
.from(knowledgeBaseTagDefinitions)
.where(
and(
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseTagDefinitions.displayName, displayName)
)
)
.limit(1)
if (existingName.length > 0) {
return NextResponse.json({ error: 'Tag name is already in use' }, { status: 409 })
}
// Create the new tag definition
const newTagDefinition = {
id: randomUUID(),
knowledgeBaseId,
tagSlot,
displayName,
fieldType,
createdAt: new Date(),
updatedAt: new Date(),
}
await db.insert(knowledgeBaseTagDefinitions).values(newTagDefinition)
logger.info(`[${requestId}] Successfully created tag definition ${displayName} (${tagSlot})`)
const newTagDefinition = await createTagDefinition(
{
knowledgeBaseId,
tagSlot: validatedData.tagSlot,
displayName: validatedData.displayName,
fieldType: validatedData.fieldType,
},
requestId
)
return NextResponse.json({
success: true,

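`createTagDefinition` is likewise not shown. A sketch under the assumption that the slot and display-name uniqueness checks removed from the route moved into the service:

import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'

// Hypothetical sketch; the real tag service may differ.
export async function createTagDefinition(
  data: { knowledgeBaseId: string; tagSlot: string; displayName: string; fieldType: string },
  requestId: string
) {
  // Preserve the slot-uniqueness check that used to live inline in the route
  // (a display-name check would follow the same shape).
  const clash = await db
    .select()
    .from(knowledgeBaseTagDefinitions)
    .where(
      and(
        eq(knowledgeBaseTagDefinitions.knowledgeBaseId, data.knowledgeBaseId),
        eq(knowledgeBaseTagDefinitions.tagSlot, data.tagSlot)
      )
    )
    .limit(1)
  if (clash.length > 0) {
    throw new Error('Tag slot is already in use')
  }
  const now = new Date()
  const record = { id: randomUUID(), ...data, createdAt: now, updatedAt: now }
  await db.insert(knowledgeBaseTagDefinitions).values(record)
  return record
}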
View File

@@ -1,11 +1,9 @@
import { randomUUID } from 'crypto'
import { and, eq, isNotNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getTagUsage } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
export const dynamic = 'force-dynamic'
@@ -24,57 +22,15 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check if user has access to the knowledge base
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
if (!accessCheck.hasAccess) {
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
}
// Get all tag definitions for the knowledge base
const tagDefinitions = await db
.select({
id: knowledgeBaseTagDefinitions.id,
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
// Get usage statistics for each tag definition
const usageStats = await Promise.all(
tagDefinitions.map(async (tagDef) => {
// Count documents using this tag slot
const tagSlotColumn = tagDef.tagSlot as keyof typeof document.$inferSelect
const documentsWithTag = await db
.select({
id: document.id,
filename: document.filename,
[tagDef.tagSlot]: document[tagSlotColumn as keyof typeof document.$inferSelect] as any,
})
.from(document)
.where(
and(
eq(document.knowledgeBaseId, knowledgeBaseId),
isNotNull(document[tagSlotColumn as keyof typeof document.$inferSelect])
)
)
return {
tagName: tagDef.displayName,
tagSlot: tagDef.tagSlot,
documentCount: documentsWithTag.length,
documents: documentsWithTag.map((doc) => ({
id: doc.id,
name: doc.filename,
tagValue: doc[tagDef.tagSlot],
})),
}
})
)
const usageStats = await getTagUsage(knowledgeBaseId, requestId)
logger.info(
`[${requestId}] Retrieved usage statistics for ${tagDefinitions.length} tag definitions`
`[${requestId}] Retrieved usage statistics for ${usageStats.length} tag definitions`
)
return NextResponse.json({

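`getTagUsage` replaces the per-tag aggregation that used to live inline. A plausible reconstruction from the removed code, with implementation details assumed:

import { and, eq, isNotNull } from 'drizzle-orm'
import { db } from '@/db'
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'

// Hypothetical sketch; the real service may batch these queries differently.
export async function getTagUsage(knowledgeBaseId: string, requestId: string) {
  const tagDefinitions = await db
    .select({
      tagSlot: knowledgeBaseTagDefinitions.tagSlot,
      displayName: knowledgeBaseTagDefinitions.displayName,
    })
    .from(knowledgeBaseTagDefinitions)
    .where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))

  // One document query per tag slot, as in the removed inline version.
  return Promise.all(
    tagDefinitions.map(async (tagDef) => {
      const column = document[tagDef.tagSlot as keyof typeof document.$inferSelect]
      const rows = await db
        .select({ id: document.id, filename: document.filename })
        .from(document)
        .where(and(eq(document.knowledgeBaseId, knowledgeBaseId), isNotNull(column as any)))
      return {
        tagName: tagDef.displayName,
        tagSlot: tagDef.tagSlot,
        documentCount: rows.length,
        documents: rows.map((doc) => ({ id: doc.id, name: doc.filename })),
      }
    })
  )
}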
View File

@@ -1,11 +1,8 @@
import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { getSession } from '@/lib/auth'
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/permissions/utils'
import { db } from '@/db'
import { document, knowledgeBase, permissions } from '@/db/schema'
const logger = createLogger('KnowledgeBaseAPI')
@@ -41,60 +38,10 @@ export async function GET(req: NextRequest) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
// Check for workspace filtering
const { searchParams } = new URL(req.url)
const workspaceId = searchParams.get('workspaceId')
// Get knowledge bases that user can access through direct ownership OR workspace permissions
const knowledgeBasesWithCounts = await db
.select({
id: knowledgeBase.id,
name: knowledgeBase.name,
description: knowledgeBase.description,
tokenCount: knowledgeBase.tokenCount,
embeddingModel: knowledgeBase.embeddingModel,
embeddingDimension: knowledgeBase.embeddingDimension,
chunkingConfig: knowledgeBase.chunkingConfig,
createdAt: knowledgeBase.createdAt,
updatedAt: knowledgeBase.updatedAt,
workspaceId: knowledgeBase.workspaceId,
docCount: count(document.id),
})
.from(knowledgeBase)
.leftJoin(
document,
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
)
.leftJoin(
permissions,
and(
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, knowledgeBase.workspaceId),
eq(permissions.userId, session.user.id)
)
)
.where(
and(
isNull(knowledgeBase.deletedAt),
workspaceId
? // When filtering by workspace
or(
// Knowledge bases belonging to the specified workspace (user must have workspace permissions)
and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
// Fallback: User-owned knowledge bases without workspace (legacy)
and(eq(knowledgeBase.userId, session.user.id), isNull(knowledgeBase.workspaceId))
)
: // When not filtering by workspace, use original logic
or(
// User owns the knowledge base directly
eq(knowledgeBase.userId, session.user.id),
// User has permissions on the knowledge base's workspace
isNotNull(permissions.userId)
)
)
)
.groupBy(knowledgeBase.id)
.orderBy(knowledgeBase.createdAt)
const knowledgeBasesWithCounts = await getKnowledgeBases(session.user.id, workspaceId)
return NextResponse.json({
success: true,
@@ -121,49 +68,16 @@ export async function POST(req: NextRequest) {
try {
const validatedData = CreateKnowledgeBaseSchema.parse(body)
// If creating in a workspace, check if user has write/admin permissions
if (validatedData.workspaceId) {
const userPermission = await getUserEntityPermissions(
session.user.id,
'workspace',
validatedData.workspaceId
)
if (userPermission !== 'write' && userPermission !== 'admin') {
logger.warn(
`[${requestId}] User ${session.user.id} denied permission to create knowledge base in workspace ${validatedData.workspaceId}`
)
return NextResponse.json(
{ error: 'Insufficient permissions to create knowledge base in this workspace' },
{ status: 403 }
)
}
}
const id = crypto.randomUUID()
const now = new Date()
const newKnowledgeBase = {
id,
userId: session.user.id,
workspaceId: validatedData.workspaceId || null,
name: validatedData.name,
description: validatedData.description || null,
tokenCount: 0,
embeddingModel: validatedData.embeddingModel,
embeddingDimension: validatedData.embeddingDimension,
chunkingConfig: validatedData.chunkingConfig || {
maxSize: 1024,
minSize: 100,
overlap: 200,
},
docCount: 0,
createdAt: now,
updatedAt: now,
}
await db.insert(knowledgeBase).values(newKnowledgeBase)
const createData = {
...validatedData,
userId: session.user.id,
}
const newKnowledgeBase = await createKnowledgeBase(createData, requestId)
logger.info(`[${requestId}] Knowledge base created: ${id} for user ${session.user.id}`)
logger.info(
`[${requestId}] Knowledge base created: ${newKnowledgeBase.id} for user ${session.user.id}`
)
return NextResponse.json({
success: true,

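`getKnowledgeBases(session.user.id, workspaceId)` absorbs the large access-control query deleted above. A sketch that preserves those access rules — the column list is trimmed for brevity, and the real service may differ:

import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
import { db } from '@/db'
import { document, knowledgeBase, permissions } from '@/db/schema'

// Hypothetical sketch mirroring the removed inline query.
export async function getKnowledgeBases(userId: string, workspaceId: string | null) {
  return db
    .select({ id: knowledgeBase.id, name: knowledgeBase.name, docCount: count(document.id) })
    .from(knowledgeBase)
    .leftJoin(
      document,
      and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
    )
    .leftJoin(
      permissions,
      and(
        eq(permissions.entityType, 'workspace'),
        eq(permissions.entityId, knowledgeBase.workspaceId),
        eq(permissions.userId, userId)
      )
    )
    .where(
      and(
        isNull(knowledgeBase.deletedAt),
        workspaceId
          ? // Workspace filter: KBs in that workspace the user can see,
            // plus legacy user-owned KBs without a workspace.
            or(
              and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
              and(eq(knowledgeBase.userId, userId), isNull(knowledgeBase.workspaceId))
            )
          : // No filter: direct ownership or any workspace permission.
            or(eq(knowledgeBase.userId, userId), isNotNull(permissions.userId))
      )
    )
    .groupBy(knowledgeBase.id)
    .orderBy(knowledgeBase.createdAt)
}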
View File

@@ -65,12 +65,14 @@ const mockHandleVectorOnlySearch = vi.fn()
const mockHandleTagAndVectorSearch = vi.fn()
const mockGetQueryStrategy = vi.fn()
const mockGenerateSearchEmbedding = vi.fn()
const mockGetDocumentNamesByIds = vi.fn()
vi.mock('./utils', () => ({
handleTagOnlySearch: mockHandleTagOnlySearch,
handleVectorOnlySearch: mockHandleVectorOnlySearch,
handleTagAndVectorSearch: mockHandleTagAndVectorSearch,
getQueryStrategy: mockGetQueryStrategy,
generateSearchEmbedding: mockGenerateSearchEmbedding,
getDocumentNamesByIds: mockGetDocumentNamesByIds,
APIError: class APIError extends Error {
public status: number
constructor(message: string, status: number) {
@@ -146,6 +148,10 @@ describe('Knowledge Search API Route', () => {
singleQueryOptimized: true,
})
mockGenerateSearchEmbedding.mockClear().mockResolvedValue([0.1, 0.2, 0.3, 0.4, 0.5])
mockGetDocumentNamesByIds.mockClear().mockResolvedValue({
doc1: 'Document 1',
doc2: 'Document 2',
})
vi.stubGlobal('crypto', {
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),

View File

@@ -1,16 +1,15 @@
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { TAG_SLOTS } from '@/lib/constants/knowledge'
import { TAG_SLOTS } from '@/lib/knowledge/consts'
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
import { createLogger } from '@/lib/logs/console/logger'
import { estimateTokenCount } from '@/lib/tokenization/estimators'
import { getUserId } from '@/app/api/auth/oauth/utils'
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
import { db } from '@/db'
import { knowledgeBaseTagDefinitions } from '@/db/schema'
import { calculateCost } from '@/providers/utils'
import {
generateSearchEmbedding,
getDocumentNamesByIds,
getQueryStrategy,
handleTagAndVectorSearch,
handleTagOnlySearch,
@@ -79,14 +78,13 @@ export async function POST(request: NextRequest) {
? validatedData.knowledgeBaseIds
: [validatedData.knowledgeBaseIds]
// Check access permissions for each knowledge base using proper workspace-based permissions
const accessibleKbIds: string[] = []
for (const kbId of knowledgeBaseIds) {
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
if (accessCheck.hasAccess) {
accessibleKbIds.push(kbId)
}
}
// Check access permissions in parallel for performance
const accessChecks = await Promise.all(
knowledgeBaseIds.map((kbId) => checkKnowledgeBaseAccess(kbId, userId))
)
const accessibleKbIds: string[] = knowledgeBaseIds.filter(
(_, idx) => accessChecks[idx]?.hasAccess
)
// Map display names to tag slots for filtering
let mappedFilters: Record<string, string> = {}
@@ -94,13 +92,7 @@ export async function POST(request: NextRequest) {
try {
// Fetch tag definitions for the first accessible KB (since we're using single KB now)
const kbId = accessibleKbIds[0]
const tagDefs = await db
.select({
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
const tagDefs = await getDocumentTagDefinitions(kbId)
logger.debug(`[${requestId}] Found tag definitions:`, tagDefs)
logger.debug(`[${requestId}] Original filters:`, validatedData.filters)
@@ -145,7 +137,10 @@ export async function POST(request: NextRequest) {
// Generate query embedding only if query is provided
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
const queryEmbedding = hasQuery ? await generateSearchEmbedding(validatedData.query!) : null
// Start embedding generation early and await when needed
const queryEmbeddingPromise = hasQuery
? generateSearchEmbedding(validatedData.query!)
: Promise.resolve(null)
// Check if any requested knowledge bases were not accessible
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
@@ -173,7 +168,7 @@ export async function POST(request: NextRequest) {
// Tag + Vector search
logger.debug(`[${requestId}] Executing tag + vector search with filters:`, mappedFilters)
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
const queryVector = JSON.stringify(queryEmbedding)
const queryVector = JSON.stringify(await queryEmbeddingPromise)
results = await handleTagAndVectorSearch({
knowledgeBaseIds: accessibleKbIds,
@@ -186,7 +181,7 @@ export async function POST(request: NextRequest) {
// Vector-only search
logger.debug(`[${requestId}] Executing vector-only search`)
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
const queryVector = JSON.stringify(queryEmbedding)
const queryVector = JSON.stringify(await queryEmbeddingPromise)
results = await handleVectorOnlySearch({
knowledgeBaseIds: accessibleKbIds,
@@ -221,30 +216,32 @@ export async function POST(request: NextRequest) {
}
// Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
for (const kbId of accessibleKbIds) {
try {
const tagDefs = await db
.select({
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
displayName: knowledgeBaseTagDefinitions.displayName,
})
.from(knowledgeBaseTagDefinitions)
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
tagDefinitionsMap[kbId] = {}
tagDefs.forEach((def) => {
tagDefinitionsMap[kbId][def.tagSlot] = def.displayName
})
logger.debug(
`[${requestId}] Display mapping - KB ${kbId} tag definitions:`,
tagDefinitionsMap[kbId]
)
} catch (error) {
logger.warn(`[${requestId}] Failed to fetch tag definitions for display mapping:`, error)
tagDefinitionsMap[kbId] = {}
}
}
const tagDefsResults = await Promise.all(
accessibleKbIds.map(async (kbId) => {
try {
const tagDefs = await getDocumentTagDefinitions(kbId)
const map: Record<string, string> = {}
tagDefs.forEach((def) => {
map[def.tagSlot] = def.displayName
})
return { kbId, map }
} catch (error) {
logger.warn(
`[${requestId}] Failed to fetch tag definitions for display mapping:`,
error
)
return { kbId, map: {} as Record<string, string> }
}
})
)
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
tagDefsResults.forEach(({ kbId, map }) => {
tagDefinitionsMap[kbId] = map
})
// Fetch document names for the results
const documentIds = results.map((result) => result.documentId)
const documentNameMap = await getDocumentNamesByIds(documentIds)
return NextResponse.json({
success: true,
@@ -271,11 +268,11 @@ export async function POST(request: NextRequest) {
})
return {
id: result.id,
content: result.content,
documentId: result.documentId,
documentName: documentNameMap[result.documentId] || undefined,
content: result.content,
chunkIndex: result.chunkIndex,
tags, // Clean display name mapped tags
metadata: tags, // Clean display name mapped tags
similarity: hasQuery ? 1 - result.distance : 1, // Perfect similarity for tag-only searches
}
}),

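The embedding change in this route follows a start-early/await-late pattern: the expensive call is kicked off before the independent access-check and filter-mapping work, and awaited only where the vector is actually needed. A self-contained illustration — the names here are placeholders, not the route's actual helpers:

// Standalone sketch of the deferred-await optimization used above.
async function slowEmbed(query: string): Promise<number[]> {
  // Stand-in for a network call to an embeddings API.
  return new Promise((resolve) => setTimeout(() => resolve([0.1, 0.2]), 100))
}

async function search(query: string | null) {
  // Start the expensive call immediately, without awaiting it...
  const embeddingPromise = query ? slowEmbed(query) : Promise.resolve(null)
  // ...do independent work (access checks, filter mapping) in the meantime...
  const accessibleKbIds = await Promise.resolve(['kb-1'])
  // ...and only await the embedding at the point of use.
  const embedding = await embeddingPromise
  return { accessibleKbIds, embedding }
}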
View File

@@ -16,7 +16,7 @@ vi.mock('@/lib/logs/console/logger', () => ({
})),
}))
vi.mock('@/db')
vi.mock('@/lib/documents/utils', () => ({
vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))

View File

@@ -1,10 +1,34 @@
import { and, eq, inArray, sql } from 'drizzle-orm'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { embedding } from '@/db/schema'
import { document, embedding } from '@/db/schema'
const logger = createLogger('KnowledgeSearchUtils')
export async function getDocumentNamesByIds(
documentIds: string[]
): Promise<Record<string, string>> {
if (documentIds.length === 0) {
return {}
}
const uniqueIds = [...new Set(documentIds)]
const documents = await db
.select({
id: document.id,
filename: document.filename,
})
.from(document)
.where(inArray(document.id, uniqueIds))
const documentNameMap: Record<string, string> = {}
documents.forEach((doc) => {
documentNameMap[doc.id] = doc.filename
})
return documentNameMap
}
export interface SearchResult {
id: string
content: string

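A hypothetical caller for `getDocumentNamesByIds`, showing how the search route can attach names to its results; passing duplicate IDs is safe because the helper deduplicates before querying:

// Illustrative usage only; the shape of `results` is assumed.
async function attachDocumentNames(
  results: Array<{ documentId: string; content: string }>
) {
  const nameMap = await getDocumentNamesByIds(results.map((r) => r.documentId))
  return results.map((r) => ({
    ...r,
    documentName: nameMap[r.documentId] || undefined, // filename from the document table
  }))
}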
View File

@@ -21,11 +21,11 @@ vi.mock('@/lib/env', () => ({
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
}))
vi.mock('@/lib/documents/utils', () => ({
vi.mock('@/lib/knowledge/documents/utils', () => ({
retryWithExponentialBackoff: (fn: any) => fn(),
}))
vi.mock('@/lib/documents/document-processor', () => ({
vi.mock('@/lib/knowledge/documents/document-processor', () => ({
processDocument: vi.fn().mockResolvedValue({
chunks: [
{
@@ -149,12 +149,12 @@ vi.mock('@/db', () => {
}
})
import { generateEmbeddings } from '@/lib/embeddings/utils'
import { processDocumentAsync } from '@/lib/knowledge/documents/service'
import {
checkChunkAccess,
checkDocumentAccess,
checkKnowledgeBaseAccess,
generateEmbeddings,
processDocumentAsync,
} from '@/app/api/knowledge/utils'
describe('Knowledge Utils', () => {

View File

@@ -1,35 +1,8 @@
import crypto from 'crypto'
import { and, eq, isNull } from 'drizzle-orm'
import { processDocument } from '@/lib/documents/document-processor'
import { generateEmbeddings } from '@/lib/embeddings/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { getUserEntityPermissions } from '@/lib/permissions/utils'
import { db } from '@/db'
import { document, embedding, knowledgeBase } from '@/db/schema'
const logger = createLogger('KnowledgeUtils')
const TIMEOUTS = {
OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes)
EMBEDDINGS_API: 60000, // 60 seconds per batch
} as const
/**
* Create a timeout wrapper for async operations
*/
function withTimeout<T>(
promise: Promise<T>,
timeoutMs: number,
operation = 'Operation'
): Promise<T> {
return Promise.race([
promise,
new Promise<never>((_, reject) =>
setTimeout(() => reject(new Error(`${operation} timed out after ${timeoutMs}ms`)), timeoutMs)
),
])
}
export interface KnowledgeBaseData {
id: string
userId: string
@@ -380,154 +353,3 @@ export async function checkChunkAccess(
knowledgeBase: kbAccess.knowledgeBase!,
}
}
// Export for external use
export { generateEmbeddings }
/**
* Process a document asynchronously with full error handling
*/
export async function processDocumentAsync(
knowledgeBaseId: string,
documentId: string,
docData: {
filename: string
fileUrl: string
fileSize: number
mimeType: string
},
processingOptions: {
chunkSize?: number
minCharactersPerChunk?: number
recipe?: string
lang?: string
chunkOverlap?: number
}
): Promise<void> {
const startTime = Date.now()
try {
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
// Set status to processing
await db
.update(document)
.set({
processingStatus: 'processing',
processingStartedAt: new Date(),
processingError: null, // Clear any previous error
})
.where(eq(document.id, documentId))
logger.info(`[${documentId}] Status updated to 'processing', starting document processor`)
// Wrap the entire processing operation with the 150-second overall timeout
await withTimeout(
(async () => {
const processed = await processDocument(
docData.fileUrl,
docData.filename,
docData.mimeType,
processingOptions.chunkSize || 1000,
processingOptions.chunkOverlap || 200,
processingOptions.minCharactersPerChunk || 1
)
const now = new Date()
logger.info(
`[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks`
)
const chunkTexts = processed.chunks.map((chunk) => chunk.text)
const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : []
logger.info(`[${documentId}] Embeddings generated, fetching document tags`)
// Fetch document to get tags
const documentRecord = await db
.select({
tag1: document.tag1,
tag2: document.tag2,
tag3: document.tag3,
tag4: document.tag4,
tag5: document.tag5,
tag6: document.tag6,
tag7: document.tag7,
})
.from(document)
.where(eq(document.id, documentId))
.limit(1)
const documentTags = documentRecord[0] || {}
logger.info(`[${documentId}] Creating embedding records with tags`)
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
id: crypto.randomUUID(),
knowledgeBaseId,
documentId,
chunkIndex,
chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'),
content: chunk.text,
contentLength: chunk.text.length,
tokenCount: Math.ceil(chunk.text.length / 4),
embedding: embeddings[chunkIndex] || null,
embeddingModel: 'text-embedding-3-small',
startOffset: chunk.metadata.startIndex,
endOffset: chunk.metadata.endIndex,
// Copy tags from document
tag1: documentTags.tag1,
tag2: documentTags.tag2,
tag3: documentTags.tag3,
tag4: documentTags.tag4,
tag5: documentTags.tag5,
tag6: documentTags.tag6,
tag7: documentTags.tag7,
createdAt: now,
updatedAt: now,
}))
await db.transaction(async (tx) => {
if (embeddingRecords.length > 0) {
await tx.insert(embedding).values(embeddingRecords)
}
await tx
.update(document)
.set({
chunkCount: processed.metadata.chunkCount,
tokenCount: processed.metadata.tokenCount,
characterCount: processed.metadata.characterCount,
processingStatus: 'completed',
processingCompletedAt: now,
processingError: null,
})
.where(eq(document.id, documentId))
})
})(),
TIMEOUTS.OVERALL_PROCESSING,
'Document processing'
)
const processingTime = Date.now() - startTime
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
} catch (error) {
const processingTime = Date.now() - startTime
logger.error(`[${documentId}] Failed to process document after ${processingTime}ms:`, {
error: error instanceof Error ? error.message : 'Unknown error',
stack: error instanceof Error ? error.stack : undefined,
filename: docData.filename,
fileUrl: docData.fileUrl,
mimeType: docData.mimeType,
})
await db
.update(document)
.set({
processingStatus: 'failed',
processingError: error instanceof Error ? error.message : 'Unknown error',
processingCompletedAt: new Date(),
})
.where(eq(document.id, documentId))
}
}

View File

@@ -4,7 +4,7 @@ import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { workflowExecutionLogs, workflowExecutionSnapshots } from '@/db/schema'
const logger = createLogger('FrozenCanvasAPI')
const logger = createLogger('LogsByExecutionIdAPI')
export async function GET(
_request: NextRequest,
@@ -13,7 +13,7 @@ export async function GET(
try {
const { executionId } = await params
logger.debug(`Fetching frozen canvas data for execution: ${executionId}`)
logger.debug(`Fetching execution data for: ${executionId}`)
// Get the workflow execution log to find the snapshot
const [workflowLog] = await db
@@ -50,14 +50,14 @@ export async function GET(
},
}
logger.debug(`Successfully fetched frozen canvas data for execution: ${executionId}`)
logger.debug(`Successfully fetched execution data for: ${executionId}`)
logger.debug(
`Workflow state contains ${Object.keys((snapshot.stateData as any)?.blocks || {}).length} blocks`
)
return NextResponse.json(response)
} catch (error) {
logger.error('Error fetching frozen canvas data:', error)
return NextResponse.json({ error: 'Failed to fetch frozen canvas data' }, { status: 500 })
logger.error('Error fetching execution data:', error)
return NextResponse.json({ error: 'Failed to fetch execution data' }, { status: 500 })
}
}

View File

@@ -0,0 +1,198 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import {
invitation,
member,
organization,
permissions,
user,
type WorkspaceInvitationStatus,
workspaceInvitation,
} from '@/db/schema'
const logger = createLogger('OrganizationInvitation')
// Get invitation details
export async function GET(
_req: NextRequest,
{ params }: { params: Promise<{ id: string; invitationId: string }> }
) {
const { id: organizationId, invitationId } = await params
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
try {
const orgInvitation = await db
.select()
.from(invitation)
.where(and(eq(invitation.id, invitationId), eq(invitation.organizationId, organizationId)))
.then((rows) => rows[0])
if (!orgInvitation) {
return NextResponse.json({ error: 'Invitation not found' }, { status: 404 })
}
const org = await db
.select()
.from(organization)
.where(eq(organization.id, organizationId))
.then((rows) => rows[0])
if (!org) {
return NextResponse.json({ error: 'Organization not found' }, { status: 404 })
}
return NextResponse.json({
invitation: orgInvitation,
organization: org,
})
} catch (error) {
logger.error('Error fetching organization invitation:', error)
return NextResponse.json({ error: 'Failed to fetch invitation' }, { status: 500 })
}
}
export async function PUT(
req: NextRequest,
{ params }: { params: Promise<{ id: string; invitationId: string }> }
) {
const { id: organizationId, invitationId } = await params
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
try {
const { status } = await req.json()
if (!status || !['accepted', 'rejected', 'cancelled'].includes(status)) {
return NextResponse.json(
{ error: 'Invalid status. Must be "accepted", "rejected", or "cancelled"' },
{ status: 400 }
)
}
const orgInvitation = await db
.select()
.from(invitation)
.where(and(eq(invitation.id, invitationId), eq(invitation.organizationId, organizationId)))
.then((rows) => rows[0])
if (!orgInvitation) {
return NextResponse.json({ error: 'Invitation not found' }, { status: 404 })
}
if (orgInvitation.status !== 'pending') {
return NextResponse.json({ error: 'Invitation already processed' }, { status: 400 })
}
if (status === 'accepted') {
const userData = await db
.select()
.from(user)
.where(eq(user.id, session.user.id))
.then((rows) => rows[0])
if (!userData || userData.email.toLowerCase() !== orgInvitation.email.toLowerCase()) {
return NextResponse.json(
{ error: 'Email mismatch. You can only accept invitations sent to your email address.' },
{ status: 403 }
)
}
}
if (status === 'cancelled') {
const isAdmin = await db
.select()
.from(member)
.where(
and(
eq(member.organizationId, organizationId),
eq(member.userId, session.user.id),
eq(member.role, 'admin')
)
)
.then((rows) => rows.length > 0)
if (!isAdmin) {
return NextResponse.json(
{ error: 'Only organization admins can cancel invitations' },
{ status: 403 }
)
}
}
await db.transaction(async (tx) => {
await tx.update(invitation).set({ status }).where(eq(invitation.id, invitationId))
if (status === 'accepted') {
await tx.insert(member).values({
id: randomUUID(),
userId: session.user.id,
organizationId,
role: orgInvitation.role,
createdAt: new Date(),
})
const linkedWorkspaceInvitations = await tx
.select()
.from(workspaceInvitation)
.where(
and(
eq(workspaceInvitation.orgInvitationId, invitationId),
eq(workspaceInvitation.status, 'pending' as WorkspaceInvitationStatus)
)
)
for (const wsInvitation of linkedWorkspaceInvitations) {
await tx
.update(workspaceInvitation)
.set({
status: 'accepted' as WorkspaceInvitationStatus,
updatedAt: new Date(),
})
.where(eq(workspaceInvitation.id, wsInvitation.id))
await tx.insert(permissions).values({
id: randomUUID(),
entityType: 'workspace',
entityId: wsInvitation.workspaceId,
userId: session.user.id,
permissionType: wsInvitation.permissions || 'read',
createdAt: new Date(),
updatedAt: new Date(),
})
}
} else if (status === 'cancelled') {
await tx
.update(workspaceInvitation)
.set({ status: 'cancelled' as WorkspaceInvitationStatus })
.where(eq(workspaceInvitation.orgInvitationId, invitationId))
}
})
logger.info(`Organization invitation ${status}`, {
organizationId,
invitationId,
userId: session.user.id,
email: orgInvitation.email,
})
return NextResponse.json({
success: true,
message: `Invitation ${status} successfully`,
invitation: { ...orgInvitation, status },
})
} catch (error) {
logger.error(`Error updating organization invitation:`, error)
return NextResponse.json({ error: 'Failed to update invitation' }, { status: 500 })
}
}

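A hypothetical client call against this handler, assuming the conventional Next.js route path `/api/organizations/[id]/invitations/[invitationId]` (the file path is not shown in this diff):

// Accept an organization invitation; the IDs are placeholders.
const res = await fetch('/api/organizations/org-123/invitations/inv-456', {
  method: 'PUT',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ status: 'accepted' }), // or 'rejected' / 'cancelled'
})
const data = await res.json()
// On success: { success: true, message: 'Invitation accepted successfully', invitation: { ... } }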
View File

@@ -1,5 +1,5 @@
import { randomUUID } from 'crypto'
import { and, eq, inArray } from 'drizzle-orm'
import { and, eq, inArray, isNull } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import {
getEmailSubject,
@@ -17,9 +17,17 @@ import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { hasWorkspaceAdminAccess } from '@/lib/permissions/utils'
import { db } from '@/db'
import { invitation, member, organization, user, workspace, workspaceInvitation } from '@/db/schema'
import {
invitation,
member,
organization,
user,
type WorkspaceInvitationStatus,
workspace,
workspaceInvitation,
} from '@/db/schema'
const logger = createLogger('OrganizationInvitationsAPI')
const logger = createLogger('OrganizationInvitations')
interface WorkspaceInvitation {
workspaceId: string
@@ -40,7 +48,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
const { id: organizationId } = await params
// Verify user has access to this organization
const memberEntry = await db
.select()
.from(member)
@@ -61,7 +68,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Forbidden - Admin access required' }, { status: 403 })
}
// Get all pending invitations for the organization
const invitations = await db
.select({
id: invitation.id,
@@ -118,10 +124,8 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
const body = await request.json()
const { email, emails, role = 'member', workspaceInvitations } = body
// Handle single invitation vs batch
const invitationEmails = email ? [email] : emails
// Validate input
if (!invitationEmails || !Array.isArray(invitationEmails) || invitationEmails.length === 0) {
return NextResponse.json({ error: 'Email or emails array is required' }, { status: 400 })
}
@@ -130,7 +134,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Invalid role' }, { status: 400 })
}
// Verify user has admin access
const memberEntry = await db
.select()
.from(member)
@@ -148,7 +151,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Forbidden - Admin access required' }, { status: 403 })
}
// Handle validation-only requests
if (validateOnly) {
const validationResult = await validateBulkInvitations(organizationId, invitationEmails)
@@ -167,7 +169,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
})
}
// Validate seat availability
const seatValidation = await validateSeatAvailability(organizationId, invitationEmails.length)
if (!seatValidation.canInvite) {
@@ -185,7 +186,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
)
}
// Get organization details
const organizationEntry = await db
.select({ name: organization.name })
.from(organization)
@@ -196,7 +196,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'Organization not found' }, { status: 404 })
}
// Validate and normalize emails
const processedEmails = invitationEmails
.map((email: string) => {
const normalized = email.trim().toLowerCase()
@@ -209,11 +208,9 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
return NextResponse.json({ error: 'No valid emails provided' }, { status: 400 })
}
// Handle batch workspace invitations if provided
const validWorkspaceInvitations: WorkspaceInvitation[] = []
if (isBatch && workspaceInvitations && workspaceInvitations.length > 0) {
for (const wsInvitation of workspaceInvitations) {
// Check if user has admin permission on this workspace
const canInvite = await hasWorkspaceAdminAccess(session.user.id, wsInvitation.workspaceId)
if (!canInvite) {
@@ -229,7 +226,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
}
}
// Check for existing members
const existingMembers = await db
.select({ userEmail: user.email })
.from(member)
@@ -239,7 +235,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
const existingEmails = existingMembers.map((m) => m.userEmail)
const newEmails = processedEmails.filter((email: string) => !existingEmails.includes(email))
// Check for existing pending invitations
const existingInvitations = await db
.select({ email: invitation.email })
.from(invitation)
@@ -265,7 +260,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
)
}
// Create invitations
const expiresAt = new Date(Date.now() + 7 * 24 * 60 * 60 * 1000) // 7 days
const invitationsToCreate = emailsToInvite.map((email: string) => ({
id: randomUUID(),
@@ -280,10 +274,10 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
await db.insert(invitation).values(invitationsToCreate)
// Create workspace invitations if batch mode
const workspaceInvitationIds: string[] = []
if (isBatch && validWorkspaceInvitations.length > 0) {
for (const email of emailsToInvite) {
const orgInviteForEmail = invitationsToCreate.find((inv) => inv.email === email)
for (const wsInvitation of validWorkspaceInvitations) {
const wsInvitationId = randomUUID()
const token = randomUUID()
@@ -297,6 +291,7 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
status: 'pending',
token,
permissions: wsInvitation.permission,
orgInvitationId: orgInviteForEmail?.id,
expiresAt,
createdAt: new Date(),
updatedAt: new Date(),
@@ -307,7 +302,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
}
}
// Send invitation emails
const inviter = await db
.select({ name: user.name })
.from(user)
@@ -320,7 +314,6 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
let emailResult
if (isBatch && validWorkspaceInvitations.length > 0) {
// Get workspace details for batch email
const workspaceDetails = await db
.select({
id: workspace.id,
@@ -346,7 +339,7 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
organizationEntry[0]?.name || 'organization',
role,
workspaceInvitationsWithNames,
`${env.NEXT_PUBLIC_APP_URL}/api/organizations/invitations/accept?id=${orgInvitation.id}`
`${env.NEXT_PUBLIC_APP_URL}/invite/${orgInvitation.id}`
)
emailResult = await sendEmail({
@@ -359,7 +352,7 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
const emailHtml = await renderInvitationEmail(
inviter[0]?.name || 'Someone',
organizationEntry[0]?.name || 'organization',
`${env.NEXT_PUBLIC_APP_URL}/api/organizations/invitations/accept?id=${orgInvitation.id}`,
`${env.NEXT_PUBLIC_APP_URL}/invite/${orgInvitation.id}`,
email
)
@@ -446,7 +439,6 @@ export async function DELETE(
)
}
// Verify user has admin access
const memberEntry = await db
.select()
.from(member)
@@ -464,12 +456,9 @@ export async function DELETE(
return NextResponse.json({ error: 'Forbidden - Admin access required' }, { status: 403 })
}
// Cancel the invitation
const result = await db
.update(invitation)
.set({
status: 'cancelled',
})
.set({ status: 'cancelled' })
.where(
and(
eq(invitation.id, invitationId),
@@ -486,6 +475,23 @@ export async function DELETE(
)
}
await db
.update(workspaceInvitation)
.set({ status: 'cancelled' as WorkspaceInvitationStatus })
.where(eq(workspaceInvitation.orgInvitationId, invitationId))
await db
.update(workspaceInvitation)
.set({ status: 'cancelled' as WorkspaceInvitationStatus })
.where(
and(
isNull(workspaceInvitation.orgInvitationId),
eq(workspaceInvitation.email, result[0].email),
eq(workspaceInvitation.status, 'pending' as WorkspaceInvitationStatus),
eq(workspaceInvitation.inviterId, session.user.id)
)
)
logger.info('Organization invitation cancelled', {
organizationId,
invitationId,

View File

@@ -81,7 +81,6 @@ export async function GET(
.select({
currentPeriodCost: userStats.currentPeriodCost,
currentUsageLimit: userStats.currentUsageLimit,
usageLimitSetBy: userStats.usageLimitSetBy,
usageLimitUpdatedAt: userStats.usageLimitUpdatedAt,
lastPeriodCost: userStats.lastPeriodCost,
})
@@ -190,6 +189,11 @@ export async function PUT(
)
}
// Prevent admins from changing other admins' roles - only owners can modify admin roles
if (targetMember[0].role === 'admin' && userMember[0].role !== 'owner') {
return NextResponse.json({ error: 'Only owners can change admin roles' }, { status: 403 })
}
// Update member role
const updatedMember = await db
.update(member)

View File

@@ -75,7 +75,6 @@ export async function GET(request: NextRequest, { params }: { params: Promise<{
userEmail: user.email,
currentPeriodCost: userStats.currentPeriodCost,
currentUsageLimit: userStats.currentUsageLimit,
usageLimitSetBy: userStats.usageLimitSetBy,
usageLimitUpdatedAt: userStats.usageLimitUpdatedAt,
})
.from(member)
@@ -261,7 +260,7 @@ export async function POST(request: NextRequest, { params }: { params: Promise<{
const emailHtml = await renderInvitationEmail(
inviter[0]?.name || 'Someone',
organizationEntry[0]?.name || 'organization',
`${env.NEXT_PUBLIC_APP_URL}/api/organizations/invitations/accept?id=${invitationId}`,
`${env.NEXT_PUBLIC_APP_URL}/invite/organization?id=${invitationId}`,
normalizedEmail
)

View File

@@ -1,333 +0,0 @@
import { randomUUID } from 'crypto'
import { and, eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { env } from '@/lib/env'
import { createLogger } from '@/lib/logs/console/logger'
import { db } from '@/db'
import { invitation, member, permissions, workspaceInvitation } from '@/db/schema'
const logger = createLogger('OrganizationInvitationAcceptanceAPI')
// Accept an organization invitation and any associated workspace invitations
export async function GET(req: NextRequest) {
const invitationId = req.nextUrl.searchParams.get('id')
if (!invitationId) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=missing-invitation-id',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
const session = await getSession()
if (!session?.user?.id) {
// Redirect to login; the user is sent back here after authenticating
return NextResponse.redirect(
new URL(
`/invite/organization?id=${invitationId}`,
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
try {
// Find the organization invitation
const invitationResult = await db
.select()
.from(invitation)
.where(eq(invitation.id, invitationId))
.limit(1)
if (invitationResult.length === 0) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=invalid-invitation',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
const orgInvitation = invitationResult[0]
// Check if invitation has expired
if (orgInvitation.expiresAt && new Date() > orgInvitation.expiresAt) {
return NextResponse.redirect(
new URL('/invite/invite-error?reason=expired', env.NEXT_PUBLIC_APP_URL || 'https://sim.ai')
)
}
// Check if invitation is still pending
if (orgInvitation.status !== 'pending') {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=already-processed',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Verify the email matches the current user
if (orgInvitation.email !== session.user.email) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=email-mismatch',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Check if user is already a member of the organization
const existingMember = await db
.select()
.from(member)
.where(
and(
eq(member.organizationId, orgInvitation.organizationId),
eq(member.userId, session.user.id)
)
)
.limit(1)
if (existingMember.length > 0) {
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=already-member',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
// Start transaction to accept both organization and workspace invitations
await db.transaction(async (tx) => {
// Accept organization invitation - add user as member
await tx.insert(member).values({
id: randomUUID(),
userId: session.user.id,
organizationId: orgInvitation.organizationId,
role: orgInvitation.role,
createdAt: new Date(),
})
// Mark organization invitation as accepted
await tx.update(invitation).set({ status: 'accepted' }).where(eq(invitation.id, invitationId))
// Find and accept any pending workspace invitations for the same email
const workspaceInvitations = await tx
.select()
.from(workspaceInvitation)
.where(
and(
eq(workspaceInvitation.email, orgInvitation.email),
eq(workspaceInvitation.status, 'pending')
)
)
for (const wsInvitation of workspaceInvitations) {
// Check if invitation hasn't expired
if (
wsInvitation.expiresAt &&
new Date().toISOString() <= wsInvitation.expiresAt.toISOString()
) {
// Check if user doesn't already have permissions on the workspace
const existingPermission = await tx
.select()
.from(permissions)
.where(
and(
eq(permissions.userId, session.user.id),
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, wsInvitation.workspaceId)
)
)
.limit(1)
if (existingPermission.length === 0) {
// Add workspace permissions
await tx.insert(permissions).values({
id: randomUUID(),
userId: session.user.id,
entityType: 'workspace',
entityId: wsInvitation.workspaceId,
permissionType: wsInvitation.permissions,
createdAt: new Date(),
updatedAt: new Date(),
})
// Mark workspace invitation as accepted
await tx
.update(workspaceInvitation)
.set({ status: 'accepted' })
.where(eq(workspaceInvitation.id, wsInvitation.id))
logger.info('Accepted workspace invitation', {
workspaceId: wsInvitation.workspaceId,
userId: session.user.id,
permission: wsInvitation.permissions,
})
}
}
}
})
logger.info('Successfully accepted batch invitation', {
organizationId: orgInvitation.organizationId,
userId: session.user.id,
role: orgInvitation.role,
})
// Redirect to success page or main app
return NextResponse.redirect(
new URL('/workspaces?invite=accepted', env.NEXT_PUBLIC_APP_URL || 'https://sim.ai')
)
} catch (error) {
logger.error('Failed to accept organization invitation', {
invitationId,
userId: session.user.id,
error,
})
return NextResponse.redirect(
new URL(
'/invite/invite-error?reason=server-error',
env.NEXT_PUBLIC_APP_URL || 'https://sim.ai'
)
)
}
}
// POST endpoint for programmatic acceptance (for API use)
export async function POST(req: NextRequest) {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
try {
const { invitationId } = await req.json()
if (!invitationId) {
return NextResponse.json({ error: 'Missing invitationId' }, { status: 400 })
}
// Similar logic to GET but return JSON response
const invitationResult = await db
.select()
.from(invitation)
.where(eq(invitation.id, invitationId))
.limit(1)
if (invitationResult.length === 0) {
return NextResponse.json({ error: 'Invalid invitation' }, { status: 404 })
}
const orgInvitation = invitationResult[0]
if (orgInvitation.expiresAt && new Date() > orgInvitation.expiresAt) {
return NextResponse.json({ error: 'Invitation expired' }, { status: 400 })
}
if (orgInvitation.status !== 'pending') {
return NextResponse.json({ error: 'Invitation already processed' }, { status: 400 })
}
if (orgInvitation.email !== session.user.email) {
return NextResponse.json({ error: 'Email mismatch' }, { status: 403 })
}
// Check if user is already a member
const existingMember = await db
.select()
.from(member)
.where(
and(
eq(member.organizationId, orgInvitation.organizationId),
eq(member.userId, session.user.id)
)
)
.limit(1)
if (existingMember.length > 0) {
return NextResponse.json({ error: 'Already a member' }, { status: 400 })
}
let acceptedWorkspaces = 0
// Accept invitations in transaction
await db.transaction(async (tx) => {
// Accept organization invitation
await tx.insert(member).values({
id: randomUUID(),
userId: session.user.id,
organizationId: orgInvitation.organizationId,
role: orgInvitation.role,
createdAt: new Date(),
})
await tx.update(invitation).set({ status: 'accepted' }).where(eq(invitation.id, invitationId))
// Accept workspace invitations
const workspaceInvitations = await tx
.select()
.from(workspaceInvitation)
.where(
and(
eq(workspaceInvitation.email, orgInvitation.email),
eq(workspaceInvitation.status, 'pending')
)
)
for (const wsInvitation of workspaceInvitations) {
if (
wsInvitation.expiresAt &&
new Date().toISOString() <= wsInvitation.expiresAt.toISOString()
) {
const existingPermission = await tx
.select()
.from(permissions)
.where(
and(
eq(permissions.userId, session.user.id),
eq(permissions.entityType, 'workspace'),
eq(permissions.entityId, wsInvitation.workspaceId)
)
)
.limit(1)
if (existingPermission.length === 0) {
await tx.insert(permissions).values({
id: randomUUID(),
userId: session.user.id,
entityType: 'workspace',
entityId: wsInvitation.workspaceId,
permissionType: wsInvitation.permissions,
createdAt: new Date(),
updatedAt: new Date(),
})
await tx
.update(workspaceInvitation)
.set({ status: 'accepted' })
.where(eq(workspaceInvitation.id, wsInvitation.id))
acceptedWorkspaces++
}
}
}
})
return NextResponse.json({
success: true,
message: `Successfully joined organization and ${acceptedWorkspaces} workspace(s)`,
organizationId: orgInvitation.organizationId,
workspacesJoined: acceptedWorkspaces,
})
} catch (error) {
logger.error('Failed to accept organization invitation via API', { error })
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}
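Note: with this file deleted, the API-route acceptance flow is removed entirely; the invitation emails generated above now deep-link to the invite pages instead. A minimal sketch of the replacement links, reusing the same env helper shown in the URL changes earlier in this diff:

// The email links that replace this deleted endpoint (see the URL changes above).
const batchInviteLink = `${env.NEXT_PUBLIC_APP_URL}/invite/${orgInvitation.id}`
const resendInviteLink = `${env.NEXT_PUBLIC_APP_URL}/invite/organization?id=${invitationId}`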

View File

@@ -0,0 +1,73 @@
import { NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { createOrganizationForTeamPlan } from '@/lib/billing/organization'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('CreateTeamOrganization')
export async function POST(request: Request) {
try {
const session = await getSession()
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized - no active session' }, { status: 401 })
}
const user = session.user
// Parse request body for optional name and slug
let organizationName = user.name
let organizationSlug: string | undefined
try {
const body = await request.json()
if (body.name && typeof body.name === 'string') {
organizationName = body.name
}
if (body.slug && typeof body.slug === 'string') {
organizationSlug = body.slug
}
} catch {
// If no body or invalid JSON, use defaults
}
logger.info('Creating organization for team plan', {
userId: user.id,
userName: user.name,
userEmail: user.email,
organizationName,
organizationSlug,
})
// Create organization and make user the owner/admin
const organizationId = await createOrganizationForTeamPlan(
user.id,
organizationName || undefined,
user.email,
organizationSlug
)
logger.info('Successfully created organization for team plan', {
userId: user.id,
organizationId,
})
return NextResponse.json({
success: true,
organizationId,
})
} catch (error) {
logger.error('Failed to create organization for team plan', {
error: error instanceof Error ? error.message : 'Unknown error',
stack: error instanceof Error ? error.stack : undefined,
})
return NextResponse.json(
{
error: 'Failed to create organization',
message: error instanceof Error ? error.message : 'Unknown error',
},
{ status: 500 }
)
}
}
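A hedged usage sketch for this new endpoint. The route path below is an assumption for illustration only (the diff does not show the file's location). Both body fields are optional; the handler falls back to the session user's name when no name is given:

// Hypothetical client call — the path is an assumption, not shown in the diff.
const res = await fetch('/api/organizations/create-team', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ name: 'Acme Inc', slug: 'acme' }), // both fields optional
})
const { success, organizationId } = await res.json()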

View File

@@ -0,0 +1,46 @@
import { type NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
const logger = createLogger('OpenRouterModelsAPI')
export const dynamic = 'force-dynamic'
export async function GET(_request: NextRequest) {
try {
const response = await fetch('https://openrouter.ai/api/v1/models', {
headers: { 'Content-Type': 'application/json' },
cache: 'no-store',
})
if (!response.ok) {
logger.warn('Failed to fetch OpenRouter models', {
status: response.status,
statusText: response.statusText,
})
return NextResponse.json({ models: [] })
}
const data = await response.json()
const models = Array.isArray(data?.data)
? Array.from(
new Set(
data.data
.map((m: any) => m?.id)
.filter((id: unknown): id is string => typeof id === 'string' && id.length > 0)
.map((id: string) => `openrouter/${id}`)
)
)
: []
logger.info('Successfully fetched OpenRouter models', {
count: models.length,
})
return NextResponse.json({ models })
} catch (error) {
logger.error('Error fetching OpenRouter models', {
error: error instanceof Error ? error.message : 'Unknown error',
})
return NextResponse.json({ models: [] })
}
}
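A small consumer sketch, assuming a route path for illustration. Because the handler answers { models: [] } on any upstream failure, callers can treat the response shape as stable:

// Hypothetical path — the handler never surfaces upstream errors to the client.
const res = await fetch('/api/providers/openrouter/models')
const { models } = (await res.json()) as { models: string[] }
// Each id is prefixed, e.g. 'openrouter/openai/gpt-4o'; duplicates are removed.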

View File

@@ -1,5 +1,6 @@
import { type NextRequest, NextResponse } from 'next/server'
import { createLogger } from '@/lib/logs/console/logger'
import { validateImageUrl } from '@/lib/security/url-validation'
const logger = createLogger('ImageProxyAPI')
@@ -17,10 +18,18 @@ export async function GET(request: NextRequest) {
return new NextResponse('Missing URL parameter', { status: 400 })
}
const urlValidation = validateImageUrl(imageUrl)
if (!urlValidation.isValid) {
logger.warn(`[${requestId}] Blocked image proxy request`, {
url: imageUrl.substring(0, 100),
error: urlValidation.error,
})
return new NextResponse(urlValidation.error || 'Invalid image URL', { status: 403 })
}
logger.info(`[${requestId}] Proxying image request for: ${imageUrl}`)
try {
// Use fetch with custom headers that appear more browser-like
const imageResponse = await fetch(imageUrl, {
headers: {
'User-Agent':
@@ -45,10 +54,8 @@ export async function GET(request: NextRequest) {
})
}
// Get image content type from response headers
const contentType = imageResponse.headers.get('content-type') || 'image/jpeg'
// Get the image as a blob
const imageBlob = await imageResponse.blob()
if (imageBlob.size === 0) {
@@ -56,7 +63,6 @@ export async function GET(request: NextRequest) {
return new NextResponse('Empty image received', { status: 404 })
}
// Return the image with appropriate headers
return new NextResponse(imageBlob, {
headers: {
'Content-Type': contentType,

View File

@@ -1,6 +1,7 @@
import { NextResponse } from 'next/server'
import { isDev } from '@/lib/environment'
import { createLogger } from '@/lib/logs/console/logger'
import { validateProxyUrl } from '@/lib/security/url-validation'
import { executeTool } from '@/tools'
import { getTool, validateRequiredParametersAfterMerge } from '@/tools/utils'
@@ -80,6 +81,15 @@ export async function GET(request: Request) {
return createErrorResponse("Missing 'url' parameter", 400)
}
const urlValidation = validateProxyUrl(targetUrl)
if (!urlValidation.isValid) {
logger.warn(`[${requestId}] Blocked proxy request`, {
url: targetUrl.substring(0, 100),
error: urlValidation.error,
})
return createErrorResponse(urlValidation.error || 'Invalid URL', 403)
}
const method = url.searchParams.get('method') || 'GET'
const bodyParam = url.searchParams.get('body')
@@ -109,7 +119,6 @@ export async function GET(request: Request) {
logger.info(`[${requestId}] Proxying ${method} request to: ${targetUrl}`)
try {
// Forward the request to the target URL with all specified headers
const response = await fetch(targetUrl, {
method: method,
headers: {
@@ -119,7 +128,6 @@ export async function GET(request: Request) {
body: body || undefined,
})
// Get response data
const contentType = response.headers.get('content-type') || ''
let data
@@ -129,7 +137,6 @@ export async function GET(request: Request) {
data = await response.text()
}
// For error responses, include a more descriptive error message
const errorMessage = !response.ok
? data && typeof data === 'object' && data.error
? `${data.error.message || JSON.stringify(data.error)}`
@@ -140,7 +147,6 @@ export async function GET(request: Request) {
logger.error(`[${requestId}] External API error: ${response.status} ${response.statusText}`)
}
// Return the proxied response
return formatResponse({
success: response.ok,
status: response.status,
@@ -166,7 +172,6 @@ export async function POST(request: Request) {
const startTimeISO = startTime.toISOString()
try {
// Parse request body
let requestBody
try {
requestBody = await request.json()
@@ -186,7 +191,6 @@ export async function POST(request: Request) {
logger.info(`[${requestId}] Processing tool: ${toolId}`)
// Get tool
const tool = getTool(toolId)
if (!tool) {
@@ -194,7 +198,6 @@ export async function POST(request: Request) {
throw new Error(`Tool not found: ${toolId}`)
}
// Validate the tool and its parameters
try {
validateRequiredParametersAfterMerge(toolId, tool, params)
} catch (validationError) {
@@ -202,7 +205,6 @@ export async function POST(request: Request) {
error: validationError instanceof Error ? validationError.message : String(validationError),
})
// Add timing information even to error responses
const endTime = new Date()
const endTimeISO = endTime.toISOString()
const duration = endTime.getTime() - startTime.getTime()
@@ -214,14 +216,12 @@ export async function POST(request: Request) {
})
}
// Check if tool has file outputs - if so, don't skip post-processing
const hasFileOutputs =
tool.outputs &&
Object.values(tool.outputs).some(
(output) => output.type === 'file' || output.type === 'file[]'
)
// Execute tool
const result = await executeTool(
toolId,
params,

View File

@@ -64,7 +64,9 @@ export async function POST(request: Request) {
return new NextResponse(
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
{ status: 500 }
{
status: 500,
}
)
}
}

View File

@@ -112,7 +112,9 @@ export async function POST(request: NextRequest) {
return new Response(
`Internal Server Error: ${error instanceof Error ? error.message : 'Unknown error'}`,
{ status: 500 }
{
status: 500,
}
)
}
}

View File

@@ -4,6 +4,8 @@ import { NextResponse } from 'next/server'
import { v4 as uuidv4 } from 'uuid'
import { z } from 'zod'
import { checkServerSideUsageLimits } from '@/lib/billing'
import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription'
import { getPersonalAndWorkspaceEnv } from '@/lib/environment/utils'
import { createLogger } from '@/lib/logs/console/logger'
import { LoggingSession } from '@/lib/logs/execution/logging-session'
import { buildTraceSpans } from '@/lib/logs/execution/trace-spans/trace-spans'
@@ -17,13 +19,7 @@ import { decryptSecret } from '@/lib/utils'
import { loadWorkflowFromNormalizedTables } from '@/lib/workflows/db-helpers'
import { updateWorkflowRunCounts } from '@/lib/workflows/utils'
import { db } from '@/db'
import {
environment as environmentTable,
subscription,
userStats,
workflow,
workflowSchedule,
} from '@/db/schema'
import { userStats, workflow, workflowSchedule } from '@/db/schema'
import { Executor } from '@/executor'
import { Serializer } from '@/serializer'
import { RateLimiter } from '@/services/queue'
@@ -113,19 +109,15 @@ export async function GET() {
continue
}
// Check rate limits for scheduled execution
const [subscriptionRecord] = await db
.select({ plan: subscription.plan })
.from(subscription)
.where(eq(subscription.referenceId, workflowRecord.userId))
.limit(1)
// Check rate limits for scheduled execution (checks both personal and org subscriptions)
const userSubscription = await getHighestPrioritySubscription(workflowRecord.userId)
const subscriptionPlan = (subscriptionRecord?.plan || 'free') as SubscriptionPlan
const subscriptionPlan = (userSubscription?.plan || 'free') as SubscriptionPlan
const rateLimiter = new RateLimiter()
const rateLimitCheck = await rateLimiter.checkRateLimit(
const rateLimitCheck = await rateLimiter.checkRateLimitWithSubscription(
workflowRecord.userId,
subscriptionPlan,
userSubscription,
'schedule',
false // schedules are always sync
)
@@ -236,20 +228,15 @@ export async function GET() {
const mergedStates = mergeSubblockState(blocks)
// Retrieve environment variables for this user (if any).
const [userEnv] = await db
.select()
.from(environmentTable)
.where(eq(environmentTable.userId, workflowRecord.userId))
.limit(1)
if (!userEnv) {
logger.debug(
`[${requestId}] No environment record found for user ${workflowRecord.userId}. Proceeding with empty variables.`
)
}
const variables = EnvVarsSchema.parse(userEnv?.variables ?? {})
// Retrieve environment variables with workspace precedence
const { personalEncrypted, workspaceEncrypted } = await getPersonalAndWorkspaceEnv(
workflowRecord.userId,
workflowRecord.workspaceId || undefined
)
const variables = EnvVarsSchema.parse({
...personalEncrypted,
...workspaceEncrypted,
})
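// Note on precedence: the object spread above means workspace-scoped variables
// override personal ones when the same key exists in both maps. Illustrative
// example with assumed values:
//   personalEncrypted  = { API_KEY: enc('personal'), REGION: enc('us-east') }
//   workspaceEncrypted = { API_KEY: enc('workspace') }
//   parsed variables   = { API_KEY: enc('workspace'), REGION: enc('us-east') }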
const currentBlockStates = await Object.entries(mergedStates).reduce(
async (accPromise, [id, block]) => {

View File

@@ -0,0 +1,114 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMongoDBConnection, sanitizeCollectionName, validateFilter } from '../utils'
const logger = createLogger('MongoDBDeleteAPI')
const DeleteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
authSource: z.string().optional(),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
collection: z.string().min(1, 'Collection name is required'),
filter: z
.union([z.string(), z.object({}).passthrough()])
.transform((val) => {
if (typeof val === 'object' && val !== null) {
return JSON.stringify(val)
}
return val
})
.refine((val) => val && val.trim() !== '' && val !== '{}', {
message: 'Filter is required for MongoDB Delete',
}),
multi: z
.union([z.boolean(), z.string(), z.undefined()])
.optional()
.transform((val) => {
if (val === 'true' || val === true) return true
if (val === 'false' || val === false) return false
return false // Default to false
}),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
let client = null
try {
const body = await request.json()
const params = DeleteSchema.parse(body)
logger.info(
`[${requestId}] Deleting document(s) from ${params.host}:${params.port}/${params.database}.${params.collection} (multi: ${params.multi})`
)
const sanitizedCollection = sanitizeCollectionName(params.collection)
const filterValidation = validateFilter(params.filter)
if (!filterValidation.isValid) {
logger.warn(`[${requestId}] Filter validation failed: ${filterValidation.error}`)
return NextResponse.json(
{ error: `Filter validation failed: ${filterValidation.error}` },
{ status: 400 }
)
}
let filterDoc
try {
filterDoc = JSON.parse(params.filter)
} catch (error) {
logger.warn(`[${requestId}] Invalid filter JSON: ${params.filter}`)
return NextResponse.json({ error: 'Invalid JSON format in filter' }, { status: 400 })
}
client = await createMongoDBConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
authSource: params.authSource,
ssl: params.ssl,
})
const db = client.db(params.database)
const coll = db.collection(sanitizedCollection)
let result
if (params.multi) {
result = await coll.deleteMany(filterDoc)
} else {
result = await coll.deleteOne(filterDoc)
}
logger.info(`[${requestId}] Delete completed: ${result.deletedCount} documents deleted`)
return NextResponse.json({
message: `${result.deletedCount} documents deleted`,
deletedCount: result.deletedCount,
})
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MongoDB delete failed:`, error)
return NextResponse.json({ error: `MongoDB delete failed: ${errorMessage}` }, { status: 500 })
} finally {
if (client) {
await client.close()
}
}
}
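Worth noting in the schema above: the filter refinement rejects both the empty string and a bare '{}', so a delete can never silently match an entire collection. A hedged illustration — conn stands in for the connection fields and is not from the diff:

// Illustrative only; `conn` is a placeholder for host/port/database/credentials.
// DeleteSchema.parse({ ...conn, collection: 'orders', filter: '{}' })       // throws: 'Filter is required for MongoDB Delete'
// DeleteSchema.parse({ ...conn, collection: 'orders', filter: { id: 42 } }) // ok — object is stringified; multi defaults to false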

View File

@@ -0,0 +1,102 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMongoDBConnection, sanitizeCollectionName, validatePipeline } from '../utils'
const logger = createLogger('MongoDBExecuteAPI')
const ExecuteSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
authSource: z.string().optional(),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
collection: z.string().min(1, 'Collection name is required'),
pipeline: z
.union([z.string(), z.array(z.object({}).passthrough())])
.transform((val) => {
if (Array.isArray(val)) {
return JSON.stringify(val)
}
return val
})
.refine((val) => val && val.trim() !== '', {
message: 'Pipeline is required',
}),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
let client = null
try {
const body = await request.json()
const params = ExecuteSchema.parse(body)
logger.info(
`[${requestId}] Executing aggregation pipeline on ${params.host}:${params.port}/${params.database}.${params.collection}`
)
const sanitizedCollection = sanitizeCollectionName(params.collection)
const pipelineValidation = validatePipeline(params.pipeline)
if (!pipelineValidation.isValid) {
logger.warn(`[${requestId}] Pipeline validation failed: ${pipelineValidation.error}`)
return NextResponse.json(
{ error: `Pipeline validation failed: ${pipelineValidation.error}` },
{ status: 400 }
)
}
const pipelineDoc = JSON.parse(params.pipeline)
client = await createMongoDBConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
authSource: params.authSource,
ssl: params.ssl,
})
const db = client.db(params.database)
const coll = db.collection(sanitizedCollection)
const cursor = coll.aggregate(pipelineDoc)
const documents = await cursor.toArray()
logger.info(
`[${requestId}] Aggregation completed successfully, returned ${documents.length} documents`
)
return NextResponse.json({
message: `Aggregation completed, returned ${documents.length} documents`,
documents,
documentCount: documents.length,
})
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MongoDB aggregation failed:`, error)
return NextResponse.json(
{ error: `MongoDB aggregation failed: ${errorMessage}` },
{ status: 500 }
)
} finally {
if (client) {
await client.close()
}
}
}

View File

@@ -0,0 +1,98 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMongoDBConnection, sanitizeCollectionName } from '../utils'
const logger = createLogger('MongoDBInsertAPI')
const InsertSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
authSource: z.string().optional(),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
collection: z.string().min(1, 'Collection name is required'),
documents: z
.union([z.array(z.record(z.unknown())), z.string()])
.transform((val) => {
if (typeof val === 'string') {
try {
const parsed = JSON.parse(val)
return Array.isArray(parsed) ? parsed : [parsed]
} catch {
throw new Error('Invalid JSON in documents field')
}
}
return val
})
.refine((val) => Array.isArray(val) && val.length > 0, {
message: 'At least one document is required',
}),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
let client = null
try {
const body = await request.json()
const params = InsertSchema.parse(body)
logger.info(
`[${requestId}] Inserting ${params.documents.length} document(s) into ${params.host}:${params.port}/${params.database}.${params.collection}`
)
const sanitizedCollection = sanitizeCollectionName(params.collection)
client = await createMongoDBConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
authSource: params.authSource,
ssl: params.ssl,
})
const db = client.db(params.database)
const coll = db.collection(sanitizedCollection)
let result
if (params.documents.length === 1) {
result = await coll.insertOne(params.documents[0] as Record<string, unknown>)
logger.info(`[${requestId}] Single document inserted successfully`)
return NextResponse.json({
message: 'Document inserted successfully',
insertedId: result.insertedId.toString(),
documentCount: 1,
})
}
result = await coll.insertMany(params.documents as Record<string, unknown>[])
const insertedCount = Object.keys(result.insertedIds).length
logger.info(`[${requestId}] ${insertedCount} documents inserted successfully`)
return NextResponse.json({
message: `${insertedCount} documents inserted successfully`,
insertedIds: Object.values(result.insertedIds).map((id) => id.toString()),
documentCount: insertedCount,
})
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MongoDB insert failed:`, error)
return NextResponse.json({ error: `MongoDB insert failed: ${errorMessage}` }, { status: 500 })
} finally {
if (client) {
await client.close()
}
}
}

View File

@@ -0,0 +1,136 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMongoDBConnection, sanitizeCollectionName, validateFilter } from '../utils'
const logger = createLogger('MongoDBQueryAPI')
const QuerySchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
authSource: z.string().optional(),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
collection: z.string().min(1, 'Collection name is required'),
query: z
.union([z.string(), z.object({}).passthrough()])
.optional()
.default('{}')
.transform((val) => {
if (typeof val === 'object' && val !== null) {
return JSON.stringify(val)
}
return val || '{}'
}),
limit: z
.union([z.coerce.number().int().positive(), z.literal(''), z.undefined()])
.optional()
.transform((val) => {
if (val === '' || val === undefined || val === null) {
return 100
}
return val
}),
sort: z
.union([z.string(), z.object({}).passthrough(), z.null()])
.optional()
.transform((val) => {
if (typeof val === 'object' && val !== null) {
return JSON.stringify(val)
}
return val
}),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
let client = null
try {
const body = await request.json()
const params = QuerySchema.parse(body)
logger.info(
`[${requestId}] Executing MongoDB query on ${params.host}:${params.port}/${params.database}.${params.collection}`
)
const sanitizedCollection = sanitizeCollectionName(params.collection)
let filter = {}
if (params.query?.trim()) {
const validation = validateFilter(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Filter validation failed: ${validation.error}`)
return NextResponse.json(
{ error: `Filter validation failed: ${validation.error}` },
{ status: 400 }
)
}
filter = JSON.parse(params.query)
}
let sortCriteria = {}
if (params.sort?.trim()) {
try {
sortCriteria = JSON.parse(params.sort)
} catch (error) {
logger.warn(`[${requestId}] Invalid sort JSON: ${params.sort}`)
return NextResponse.json({ error: 'Invalid JSON format in sort criteria' }, { status: 400 })
}
}
client = await createMongoDBConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
authSource: params.authSource,
ssl: params.ssl,
})
const db = client.db(params.database)
const coll = db.collection(sanitizedCollection)
let cursor = coll.find(filter)
if (Object.keys(sortCriteria).length > 0) {
cursor = cursor.sort(sortCriteria)
}
const limit = params.limit || 100
cursor = cursor.limit(limit)
const documents = await cursor.toArray()
logger.info(
`[${requestId}] Query executed successfully, returned ${documents.length} documents`
)
return NextResponse.json({
message: `Found ${documents.length} documents`,
documents,
documentCount: documents.length,
})
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MongoDB query failed:`, error)
return NextResponse.json({ error: `MongoDB query failed: ${errorMessage}` }, { status: 500 })
} finally {
if (client) {
await client.close()
}
}
}

View File

@@ -0,0 +1,143 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import { createMongoDBConnection, sanitizeCollectionName, validateFilter } from '../utils'
const logger = createLogger('MongoDBUpdateAPI')
const UpdateSchema = z.object({
host: z.string().min(1, 'Host is required'),
port: z.coerce.number().int().positive('Port must be a positive integer'),
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
authSource: z.string().optional(),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
collection: z.string().min(1, 'Collection name is required'),
filter: z
.union([z.string(), z.object({}).passthrough()])
.transform((val) => {
if (typeof val === 'object' && val !== null) {
return JSON.stringify(val)
}
return val
})
.refine((val) => val && val.trim() !== '' && val !== '{}', {
message: 'Filter is required for MongoDB Update',
}),
update: z
.union([z.string(), z.object({}).passthrough()])
.transform((val) => {
if (typeof val === 'object' && val !== null) {
return JSON.stringify(val)
}
return val
})
.refine((val) => val && val.trim() !== '', {
message: 'Update is required',
}),
upsert: z
.union([z.boolean(), z.string(), z.undefined()])
.optional()
.transform((val) => {
if (val === 'true' || val === true) return true
if (val === 'false' || val === false) return false
return false
}),
multi: z
.union([z.boolean(), z.string(), z.undefined()])
.optional()
.transform((val) => {
if (val === 'true' || val === true) return true
if (val === 'false' || val === false) return false
return false
}),
})
export async function POST(request: NextRequest) {
const requestId = randomUUID().slice(0, 8)
let client = null
try {
const body = await request.json()
const params = UpdateSchema.parse(body)
logger.info(
`[${requestId}] Updating document(s) in ${params.host}:${params.port}/${params.database}.${params.collection} (multi: ${params.multi}, upsert: ${params.upsert})`
)
const sanitizedCollection = sanitizeCollectionName(params.collection)
const filterValidation = validateFilter(params.filter)
if (!filterValidation.isValid) {
logger.warn(`[${requestId}] Filter validation failed: ${filterValidation.error}`)
return NextResponse.json(
{ error: `Filter validation failed: ${filterValidation.error}` },
{ status: 400 }
)
}
let filterDoc
let updateDoc
try {
filterDoc = JSON.parse(params.filter)
updateDoc = JSON.parse(params.update)
} catch (error) {
logger.warn(`[${requestId}] Invalid JSON in filter or update`)
return NextResponse.json(
{ error: 'Invalid JSON format in filter or update' },
{ status: 400 }
)
}
client = await createMongoDBConnection({
host: params.host,
port: params.port,
database: params.database,
username: params.username,
password: params.password,
authSource: params.authSource,
ssl: params.ssl,
})
const db = client.db(params.database)
const coll = db.collection(sanitizedCollection)
let result
if (params.multi) {
result = await coll.updateMany(filterDoc, updateDoc, { upsert: params.upsert })
} else {
result = await coll.updateOne(filterDoc, updateDoc, { upsert: params.upsert })
}
logger.info(
`[${requestId}] Update completed: ${result.modifiedCount} modified, ${result.matchedCount} matched${result.upsertedCount ? `, ${result.upsertedCount} upserted` : ''}`
)
return NextResponse.json({
message: `${result.modifiedCount} documents updated${result.upsertedCount ? `, ${result.upsertedCount} documents upserted` : ''}`,
matchedCount: result.matchedCount,
modifiedCount: result.modifiedCount,
documentCount: result.modifiedCount + (result.upsertedCount || 0),
...(result.upsertedId && { insertedId: result.upsertedId.toString() }),
})
} catch (error) {
if (error instanceof z.ZodError) {
logger.warn(`[${requestId}] Invalid request data`, { errors: error.errors })
return NextResponse.json(
{ error: 'Invalid request data', details: error.errors },
{ status: 400 }
)
}
const errorMessage = error instanceof Error ? error.message : 'Unknown error occurred'
logger.error(`[${requestId}] MongoDB update failed:`, error)
return NextResponse.json({ error: `MongoDB update failed: ${errorMessage}` }, { status: 500 })
} finally {
if (client) {
await client.close()
}
}
}
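A hedged usage sketch for the update route. The path is an assumption (the diff shows only relative imports); object-valued filter and update are stringified by the schema before validation:

// Hypothetical path — assumed for illustration only.
const res = await fetch('/api/tools/mongodb/update', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    host: 'localhost', port: 27017, database: 'app',
    username: 'sim', password: 'secret', ssl: 'preferred',
    collection: 'orders',
    filter: { status: 'pending' },        // objects are stringified by the schema
    update: { $set: { status: 'paid' } },
    multi: true,                          // 'true'/'false' strings are coerced too
  }),
})
// -> { message, matchedCount, modifiedCount, documentCount, insertedId? }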

View File

@@ -0,0 +1,123 @@
import { MongoClient } from 'mongodb'
import type { MongoDBConnectionConfig } from '@/tools/mongodb/types'
export async function createMongoDBConnection(config: MongoDBConnectionConfig) {
const credentials =
config.username && config.password
? `${encodeURIComponent(config.username)}:${encodeURIComponent(config.password)}@`
: ''
const queryParams = new URLSearchParams()
if (config.authSource) {
queryParams.append('authSource', config.authSource)
}
if (config.ssl === 'required') {
queryParams.append('ssl', 'true')
}
const queryString = queryParams.toString()
const uri = `mongodb://${credentials}${config.host}:${config.port}/${config.database}${queryString ? `?${queryString}` : ''}`
const client = new MongoClient(uri, {
connectTimeoutMS: 10000,
socketTimeoutMS: 10000,
maxPoolSize: 1,
})
await client.connect()
return client
}
export function validateFilter(filter: string): { isValid: boolean; error?: string } {
try {
const parsed = JSON.parse(filter)
const dangerousOperators = ['$where', '$regex', '$expr', '$function', '$accumulator', '$let']
const checkForDangerousOps = (obj: any): boolean => {
if (typeof obj !== 'object' || obj === null) return false
for (const key of Object.keys(obj)) {
if (dangerousOperators.includes(key)) return true
if (typeof obj[key] === 'object' && checkForDangerousOps(obj[key])) return true
}
return false
}
if (checkForDangerousOps(parsed)) {
return {
isValid: false,
error: 'Filter contains potentially dangerous operators',
}
}
return { isValid: true }
} catch (error) {
return {
isValid: false,
error: 'Invalid JSON format in filter',
}
}
}
export function validatePipeline(pipeline: string): { isValid: boolean; error?: string } {
try {
const parsed = JSON.parse(pipeline)
if (!Array.isArray(parsed)) {
return {
isValid: false,
error: 'Pipeline must be an array',
}
}
const dangerousOperators = [
'$where',
'$function',
'$accumulator',
'$let',
'$merge',
'$out',
'$currentOp',
'$listSessions',
'$listLocalSessions',
]
const checkPipelineStage = (stage: any): boolean => {
if (typeof stage !== 'object' || stage === null) return false
for (const key of Object.keys(stage)) {
if (dangerousOperators.includes(key)) return true
if (typeof stage[key] === 'object' && checkPipelineStage(stage[key])) return true
}
return false
}
for (const stage of parsed) {
if (checkPipelineStage(stage)) {
return {
isValid: false,
error: 'Pipeline contains potentially dangerous operators',
}
}
}
return { isValid: true }
} catch (error) {
return {
isValid: false,
error: 'Invalid JSON format in pipeline',
}
}
}
export function sanitizeCollectionName(name: string): string {
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(name)) {
throw new Error(
'Invalid collection name. Must start with a letter or underscore and contain only letters, numbers, and underscores.'
)
}
return name
}
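Hedged examples of what these validators accept and reject, based on the operator lists above:

// Illustrative calls (assumed usage):
validateFilter('{"age": {"$gt": 30}}')       // { isValid: true }
validateFilter('{"$where": "this.a == 1"}')  // { isValid: false, error: '...dangerous operators' }
validatePipeline('[{"$match": {"a": 1}}]')   // { isValid: true }
validatePipeline('[{"$out": "other_coll"}]') // { isValid: false, error: '...dangerous operators' }
sanitizeCollectionName('users_2024')         // 'users_2024'
sanitizeCollectionName('users;drop')         // throws: invalid collection name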

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,13 +12,13 @@ const DeleteSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,12 +12,12 @@ const ExecuteSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -26,7 +27,6 @@ export async function POST(request: NextRequest) {
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
)
// Validate query before execution
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,7 +12,7 @@ const InsertSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
@@ -38,13 +39,10 @@ const InsertSchema = z.object({
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
logger.info(`[${requestId}] Received data field type: ${typeof body.data}, value:`, body.data)
const params = InsertSchema.parse(body)
logger.info(

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,12 +12,12 @@ const QuerySchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -26,7 +27,6 @@ export async function POST(request: NextRequest) {
`[${requestId}] Executing MySQL query on ${params.host}:${params.port}/${params.database}`
)
// Validate query before execution
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,7 +12,7 @@ const UpdateSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
@@ -36,7 +37,7 @@ const UpdateSchema = z.object({
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()

View File

@@ -6,7 +6,7 @@ export interface MySQLConnectionConfig {
database: string
username: string
password: string
ssl?: string
ssl?: 'disabled' | 'required' | 'preferred'
}
export async function createMySQLConnection(config: MySQLConnectionConfig) {
@@ -18,13 +18,13 @@ export async function createMySQLConnection(config: MySQLConnectionConfig) {
password: config.password,
}
// Handle SSL configuration
if (config.ssl === 'required') {
if (config.ssl === 'disabled') {
// Leave the ssl property unset to disable SSL
} else if (config.ssl === 'required') {
connectionConfig.ssl = { rejectUnauthorized: true }
} else if (config.ssl === 'preferred') {
connectionConfig.ssl = { rejectUnauthorized: false }
}
// For 'disabled', we don't set the ssl property at all
return mysql.createConnection(connectionConfig)
}
@@ -54,7 +54,6 @@ export async function executeQuery(
export function validateQuery(query: string): { isValid: boolean; error?: string } {
const trimmedQuery = query.trim().toLowerCase()
// Block dangerous SQL operations
const dangerousPatterns = [
/drop\s+database/i,
/drop\s+schema/i,
@@ -91,7 +90,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
}
}
// Only allow specific statement types for execute endpoint
const allowedStatements = /^(select|insert|update|delete|with|show|describe|explain)\s+/i
if (!allowedStatements.test(trimmedQuery)) {
return {
@@ -116,6 +114,8 @@ export function buildInsertQuery(table: string, data: Record<string, unknown>) {
}
export function buildUpdateQuery(table: string, data: Record<string, unknown>, where: string) {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const values = Object.values(data)
@@ -127,14 +127,33 @@ export function buildUpdateQuery(table: string, data: Record<string, unknown>, w
}
export function buildDeleteQuery(table: string, where: string) {
validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const query = `DELETE FROM ${sanitizedTable} WHERE ${where}`
return { query, values: [] }
}
function validateWhereClause(where: string): void {
const dangerousPatterns = [
/;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
/union\s+select/i,
/into\s+outfile/i,
/load_file/i,
/--/,
/\/\*/,
/\*\//,
]
for (const pattern of dangerousPatterns) {
if (pattern.test(where)) {
throw new Error('WHERE clause contains potentially dangerous operation')
}
}
}
export function sanitizeIdentifier(identifier: string): string {
// Handle schema.table format
if (identifier.includes('.')) {
const parts = identifier.split('.')
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
@@ -144,16 +163,13 @@ export function sanitizeIdentifier(identifier: string): string {
}
function sanitizeSingleIdentifier(identifier: string): string {
// Remove any existing backticks to prevent double-escaping
const cleaned = identifier.replace(/`/g, '')
// Validate identifier contains only safe characters
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
throw new Error(
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
)
}
// Wrap in backticks for MySQL
return `\`${cleaned}\``
}
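Hedged examples of the identifier and WHERE-clause guards above:

// Illustrative calls (assumed usage):
sanitizeIdentifier('orders')        // '`orders`'
sanitizeIdentifier('shop.orders')   // '`shop`.`orders`'
buildDeleteQuery('orders', 'id = 42')
// -> { query: 'DELETE FROM `orders` WHERE id = 42', values: [] }
buildDeleteQuery('orders', '1=1; DROP TABLE users') // throws: dangerous operation
buildDeleteQuery('orders', 'id = 1 -- note')        // throws: SQL comments rejected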

View File

@@ -1,11 +1,8 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import {
buildDeleteQuery,
createPostgresConnection,
executeQuery,
} from '@/app/api/tools/postgresql/utils'
import { createPostgresConnection, executeDelete } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLDeleteAPI')
@@ -15,13 +12,13 @@ const DeleteSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
where: z.string().min(1, 'WHERE clause is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -31,7 +28,7 @@ export async function POST(request: NextRequest) {
`[${requestId}] Deleting data from ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const client = await createPostgresConnection({
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
@@ -41,8 +38,7 @@ export async function POST(request: NextRequest) {
})
try {
const { query, values } = buildDeleteQuery(params.table, params.where)
const result = await executeQuery(client, query, values)
const result = await executeDelete(sql, params.table, params.where)
logger.info(`[${requestId}] Delete executed successfully, ${result.rowCount} row(s) deleted`)
@@ -52,7 +48,7 @@ export async function POST(request: NextRequest) {
rowCount: result.rowCount,
})
} finally {
await client.end()
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -15,12 +16,12 @@ const ExecuteSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -30,7 +31,6 @@ export async function POST(request: NextRequest) {
`[${requestId}] Executing raw SQL on ${params.host}:${params.port}/${params.database}`
)
// Validate query before execution
const validation = validateQuery(params.query)
if (!validation.isValid) {
logger.warn(`[${requestId}] Query validation failed: ${validation.error}`)
@@ -40,7 +40,7 @@ export async function POST(request: NextRequest) {
)
}
const client = await createPostgresConnection({
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
@@ -50,7 +50,7 @@ export async function POST(request: NextRequest) {
})
try {
const result = await executeQuery(client, params.query)
const result = await executeQuery(sql, params.query)
logger.info(`[${requestId}] SQL executed successfully, ${result.rowCount} row(s) affected`)
@@ -60,7 +60,7 @@ export async function POST(request: NextRequest) {
rowCount: result.rowCount,
})
} finally {
await client.end()
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {

View File

@@ -1,11 +1,8 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
import {
buildInsertQuery,
createPostgresConnection,
executeQuery,
} from '@/app/api/tools/postgresql/utils'
import { createPostgresConnection, executeInsert } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLInsertAPI')
@@ -15,7 +12,7 @@ const InsertSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
@@ -42,21 +39,18 @@ const InsertSchema = z.object({
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
// Debug: Log the data field to see what we're getting
logger.info(`[${requestId}] Received data field type: ${typeof body.data}, value:`, body.data)
const params = InsertSchema.parse(body)
logger.info(
`[${requestId}] Inserting data into ${params.table} on ${params.host}:${params.port}/${params.database}`
)
const client = await createPostgresConnection({
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
@@ -66,8 +60,7 @@ export async function POST(request: NextRequest) {
})
try {
const { query, values } = buildInsertQuery(params.table, params.data)
const result = await executeQuery(client, query, values)
const result = await executeInsert(sql, params.table, params.data)
logger.info(`[${requestId}] Insert executed successfully, ${result.rowCount} row(s) inserted`)
@@ -77,7 +70,7 @@ export async function POST(request: NextRequest) {
rowCount: result.rowCount,
})
} finally {
await client.end()
await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {

View File

@@ -1,3 +1,4 @@
import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
@@ -11,12 +12,12 @@ const QuerySchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
query: z.string().min(1, 'Query is required'),
})
export async function POST(request: NextRequest) {
const requestId = crypto.randomUUID().slice(0, 8)
const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -26,7 +27,7 @@ export async function POST(request: NextRequest) {
`[${requestId}] Executing PostgreSQL query on ${params.host}:${params.port}/${params.database}`
)
const client = await createPostgresConnection({
const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
@@ -36,7 +37,7 @@ export async function POST(request: NextRequest) {
})
try {
-      const result = await executeQuery(client, params.query)
+      const result = await executeQuery(sql, params.query)
logger.info(`[${requestId}] Query executed successfully, returned ${result.rowCount} rows`)
@@ -46,7 +47,7 @@ export async function POST(request: NextRequest) {
rowCount: result.rowCount,
})
} finally {
-      await client.end()
+      await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
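Two smaller changes repeat in each route: request IDs now come from the imported randomUUID instead of the crypto global, and the SSL default relaxes from 'required' to 'preferred'. A sketch of how the parsed params pick up that default, assuming the QuerySchema above (field values are illustrative):

```typescript
import { randomUUID } from 'crypto'

const requestId = randomUUID().slice(0, 8) // e.g. 'a1b2c3d4'

// Zod applies the new default when the caller omits ssl:
const params = QuerySchema.parse({
  host: 'db.internal',
  port: 5432,
  database: 'demo',
  username: 'demo',
  password: 'secret',
  query: 'SELECT 1',
})
console.log(params.ssl) // 'preferred'
```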

View File

@@ -1,11 +1,8 @@
+ import { randomUUID } from 'crypto'
import { type NextRequest, NextResponse } from 'next/server'
import { z } from 'zod'
import { createLogger } from '@/lib/logs/console/logger'
- import {
-   buildUpdateQuery,
-   createPostgresConnection,
-   executeQuery,
- } from '@/app/api/tools/postgresql/utils'
+ import { createPostgresConnection, executeUpdate } from '@/app/api/tools/postgresql/utils'
const logger = createLogger('PostgreSQLUpdateAPI')
@@ -15,7 +12,7 @@ const UpdateSchema = z.object({
database: z.string().min(1, 'Database name is required'),
username: z.string().min(1, 'Username is required'),
password: z.string().min(1, 'Password is required'),
- ssl: z.enum(['disabled', 'required', 'preferred']).default('required'),
+ ssl: z.enum(['disabled', 'required', 'preferred']).default('preferred'),
table: z.string().min(1, 'Table name is required'),
data: z.union([
z
@@ -40,7 +37,7 @@ const UpdateSchema = z.object({
})
export async function POST(request: NextRequest) {
- const requestId = crypto.randomUUID().slice(0, 8)
+ const requestId = randomUUID().slice(0, 8)
try {
const body = await request.json()
@@ -50,7 +47,7 @@ export async function POST(request: NextRequest) {
`[${requestId}] Updating data in ${params.table} on ${params.host}:${params.port}/${params.database}`
)
- const client = await createPostgresConnection({
+ const sql = createPostgresConnection({
host: params.host,
port: params.port,
database: params.database,
@@ -60,8 +57,7 @@ export async function POST(request: NextRequest) {
})
try {
-      const { query, values } = buildUpdateQuery(params.table, params.data, params.where)
-      const result = await executeQuery(client, query, values)
+      const result = await executeUpdate(sql, params.table, params.data, params.where)
logger.info(`[${requestId}] Update executed successfully, ${result.rowCount} row(s) updated`)
@@ -71,7 +67,7 @@ export async function POST(request: NextRequest) {
rowCount: result.rowCount,
})
} finally {
-      await client.end()
+      await sql.end()
}
} catch (error) {
if (error instanceof z.ZodError) {
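The update route gets the matching consolidation: buildUpdateQuery plus executeQuery collapse into executeUpdate, which also validates the WHERE clause before touching the database. A short sketch with hypothetical arguments, reusing the connection setup sketched earlier:

```typescript
// `sql` comes from createPostgresConnection(...) as in the insert example above.
// Issues roughly: UPDATE "users" SET "name" = $1 WHERE id = 42 RETURNING *
const result = await executeUpdate(sql, 'users', { name: 'Grace' }, 'id = 42')
console.log(`${result.rowCount} row(s) updated`) // the WHERE string is checked by validateWhereClause first
```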

View File

@@ -1,43 +1,41 @@
- import { Client } from 'pg'
+ import postgres from 'postgres'
import type { PostgresConnectionConfig } from '@/tools/postgresql/types'
- export async function createPostgresConnection(config: PostgresConnectionConfig): Promise<Client> {
-   const client = new Client({
+ export function createPostgresConnection(config: PostgresConnectionConfig) {
+   const sslConfig =
+     config.ssl === 'disabled'
+       ? false
+       : config.ssl === 'required'
+         ? 'require'
+         : config.ssl === 'preferred'
+           ? 'prefer'
+           : 'require'
+   const sql = postgres({
host: config.host,
port: config.port,
database: config.database,
-     user: config.username,
+     username: config.username,
password: config.password,
-     ssl:
-       config.ssl === 'disabled'
-         ? false
-         : config.ssl === 'required'
-           ? true
-           : config.ssl === 'preferred'
-             ? { rejectUnauthorized: false }
-             : false,
-     connectionTimeoutMillis: 10000, // 10 seconds
-     query_timeout: 30000, // 30 seconds
+     ssl: sslConfig,
+     connect_timeout: 10, // 10 seconds
+     idle_timeout: 20, // 20 seconds
+     max_lifetime: 60 * 30, // 30 minutes
+     max: 1, // Single connection for tool usage
})
-   try {
-     await client.connect()
-     return client
-   } catch (error) {
-     await client.end()
-     throw error
-   }
+   return sql
}
export async function executeQuery(
-   client: Client,
+   sql: any,
query: string,
params: unknown[] = []
): Promise<{ rows: unknown[]; rowCount: number }> {
-   const result = await client.query(query, params)
+   const result = await sql.unsafe(query, params)
return {
-     rows: result.rows || [],
-     rowCount: result.rowCount || 0,
+     rows: Array.isArray(result) ? result : [result],
+     rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
}
}
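sql.unsafe resolves to an array of rows rather than a pg result object, so the helper normalizes the shape itself, including the edge case of a single-object result. The same logic in isolation, probed with hypothetical inputs (a mirror of the lines above, for illustration only):

```typescript
// Mirror of the normalization above.
function normalize(result: unknown): { rows: unknown[]; rowCount: number } {
  return {
    rows: Array.isArray(result) ? result : [result],
    rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
  }
}

console.log(normalize([{ id: 1 }, { id: 2 }]).rowCount) // 2
console.log(normalize(undefined).rowCount) // 0 (a falsy result counts as no rows)
```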
@@ -84,7 +82,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
}
}
- // Only allow specific statement types for execute endpoint
const allowedStatements = /^(select|insert|update|delete|with|explain|analyze|show)\s+/i
if (!allowedStatements.test(trimmedQuery)) {
return {
@@ -98,7 +95,6 @@ export function validateQuery(query: string): { isValid: boolean; error?: string
}
export function sanitizeIdentifier(identifier: string): string {
- // Handle schema.table format
if (identifier.includes('.')) {
const parts = identifier.split('.')
return parts.map((part) => sanitizeSingleIdentifier(part)).join('.')
@@ -107,28 +103,41 @@ export function sanitizeIdentifier(identifier: string): string {
return sanitizeSingleIdentifier(identifier)
}
+ function validateWhereClause(where: string): void {
+   const dangerousPatterns = [
+     /;\s*(drop|delete|insert|update|create|alter|grant|revoke)/i,
+     /union\s+select/i,
+     /into\s+outfile/i,
+     /load_file/i,
+     /--/,
+     /\/\*/,
+     /\*\//,
+   ]
+   for (const pattern of dangerousPatterns) {
+     if (pattern.test(where)) {
+       throw new Error('WHERE clause contains potentially dangerous operation')
+     }
+   }
+ }
function sanitizeSingleIdentifier(identifier: string): string {
// Remove any existing double quotes to prevent double-escaping
const cleaned = identifier.replace(/"/g, '')
// Validate identifier contains only safe characters
if (!/^[a-zA-Z_][a-zA-Z0-9_]*$/.test(cleaned)) {
throw new Error(
`Invalid identifier: ${identifier}. Identifiers must start with a letter or underscore and contain only letters, numbers, and underscores.`
)
}
// Wrap in double quotes for PostgreSQL
return `"${cleaned}"`
}
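sanitizeIdentifier double-quotes table and column names after a strict character check, and the new validateWhereClause rejects obviously hostile WHERE strings (stacked statements, UNION SELECT, comment markers). A few illustrative probes, run from inside the module since validateWhereClause is not exported:

```typescript
console.log(sanitizeIdentifier('public.users')) // "public"."users"

try {
  sanitizeIdentifier('users; DROP TABLE x') // fails the identifier regex
} catch (e) {
  console.log((e as Error).message) // Invalid identifier: ...
}

validateWhereClause('id = 42') // passes silently

try {
  validateWhereClause('1 = 1 UNION SELECT password FROM users')
} catch (e) {
  console.log((e as Error).message) // 'WHERE clause contains potentially dangerous operation'
}
```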
- export function buildInsertQuery(
+ export async function executeInsert(
+   sql: any,
table: string,
data: Record<string, unknown>
- ): {
-   query: string
-   values: unknown[]
- } {
+ ): Promise<{ rows: unknown[]; rowCount: number }> {
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
@@ -136,18 +145,22 @@ export function buildInsertQuery(
const values = columns.map((col) => data[col])
const query = `INSERT INTO ${sanitizedTable} (${sanitizedColumns.join(', ')}) VALUES (${placeholders.join(', ')}) RETURNING *`
-   return { query, values }
+   const result = await sql.unsafe(query, values)
+   return {
+     rows: Array.isArray(result) ? result : [result],
+     rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
+   }
}
- export function buildUpdateQuery(
+ export async function executeUpdate(
+   sql: any,
table: string,
data: Record<string, unknown>,
where: string
- ): {
-   query: string
-   values: unknown[]
- } {
+ ): Promise<{ rows: unknown[]; rowCount: number }> {
+   validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const columns = Object.keys(data)
const sanitizedColumns = columns.map((col) => sanitizeIdentifier(col))
@@ -155,19 +168,27 @@ export function buildUpdateQuery(
const values = columns.map((col) => data[col])
const query = `UPDATE ${sanitizedTable} SET ${setClause} WHERE ${where} RETURNING *`
-   return { query, values }
+   const result = await sql.unsafe(query, values)
+   return {
+     rows: Array.isArray(result) ? result : [result],
+     rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
+   }
}
- export function buildDeleteQuery(
+ export async function executeDelete(
+   sql: any,
table: string,
where: string
- ): {
-   query: string
-   values: unknown[]
- } {
+ ): Promise<{ rows: unknown[]; rowCount: number }> {
+   validateWhereClause(where)
const sanitizedTable = sanitizeIdentifier(table)
const query = `DELETE FROM ${sanitizedTable} WHERE ${where} RETURNING *`
-   return { query, values: [] }
+   const result = await sql.unsafe(query, [])
+   return {
+     rows: Array.isArray(result) ? result : [result],
+     rowCount: Array.isArray(result) ? result.length : result ? 1 : 0,
+   }
}
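executeDelete rounds out the set with the same contract: WHERE validation up front, RETURNING * to report what was removed. A short sketch with a hypothetical table and predicate:

```typescript
import { createPostgresConnection, executeDelete } from '@/app/api/tools/postgresql/utils'

const sql = createPostgresConnection({
  host: 'db.internal', // hypothetical connection details
  port: 5432,
  database: 'demo',
  username: 'demo',
  password: 'secret',
  ssl: 'preferred',
})

try {
  // Issues roughly: DELETE FROM "sessions" WHERE expires_at < now() RETURNING *
  const result = await executeDelete(sql, 'sessions', 'expires_at < now()')
  console.log(`${result.rowCount} row(s) deleted`)
} finally {
  await sql.end()
}
```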

View File

@@ -1,179 +0,0 @@
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { getUserUsageLimitInfo, updateUserUsageLimit } from '@/lib/billing'
import { updateMemberUsageLimit } from '@/lib/billing/core/organization-billing'
import { createLogger } from '@/lib/logs/console/logger'
import { isOrganizationOwnerOrAdmin } from '@/lib/permissions/utils'
const logger = createLogger('UnifiedUsageLimitsAPI')
/**
* Unified Usage Limits Endpoint
* GET/PUT /api/usage-limits?context=user|member&userId=<id>&organizationId=<id>
*
*/
export async function GET(request: NextRequest) {
const session = await getSession()
try {
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const { searchParams } = new URL(request.url)
const context = searchParams.get('context') || 'user'
const userId = searchParams.get('userId') || session.user.id
const organizationId = searchParams.get('organizationId')
// Validate context
if (!['user', 'member'].includes(context)) {
return NextResponse.json(
{ error: 'Invalid context. Must be "user" or "member"' },
{ status: 400 }
)
}
// For member context, require organizationId and check permissions
if (context === 'member') {
if (!organizationId) {
return NextResponse.json(
{ error: 'Organization ID is required when context=member' },
{ status: 400 }
)
}
// Check if the current user has permission to view member usage info
const hasPermission = await isOrganizationOwnerOrAdmin(session.user.id, organizationId)
if (!hasPermission) {
logger.warn('Unauthorized attempt to view member usage info', {
requesterId: session.user.id,
targetUserId: userId,
organizationId,
})
return NextResponse.json(
{
error:
'Permission denied. Only organization owners and admins can view member usage information',
},
{ status: 403 }
)
}
}
// For user context, ensure they can only view their own info
if (context === 'user' && userId !== session.user.id) {
return NextResponse.json(
{ error: "Cannot view other users' usage information" },
{ status: 403 }
)
}
// Get usage limit info
const usageLimitInfo = await getUserUsageLimitInfo(userId)
return NextResponse.json({
success: true,
context,
userId,
organizationId,
data: usageLimitInfo,
})
} catch (error) {
logger.error('Failed to get usage limit info', {
userId: session?.user?.id,
error,
})
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}
export async function PUT(request: NextRequest) {
const session = await getSession()
try {
if (!session?.user?.id) {
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
}
const { searchParams } = new URL(request.url)
const context = searchParams.get('context') || 'user'
const userId = searchParams.get('userId') || session.user.id
const organizationId = searchParams.get('organizationId')
const { limit } = await request.json()
if (typeof limit !== 'number' || limit < 0) {
return NextResponse.json(
{ error: 'Invalid limit. Must be a positive number' },
{ status: 400 }
)
}
if (context === 'user') {
// Update user's own usage limit
if (userId !== session.user.id) {
return NextResponse.json({ error: "Cannot update other users' limits" }, { status: 403 })
}
await updateUserUsageLimit(userId, limit)
} else if (context === 'member') {
// Update organization member's usage limit
if (!organizationId) {
return NextResponse.json(
{ error: 'Organization ID is required when context=member' },
{ status: 400 }
)
}
// Check if the current user has permission to update member limits
const hasPermission = await isOrganizationOwnerOrAdmin(session.user.id, organizationId)
if (!hasPermission) {
logger.warn('Unauthorized attempt to update member usage limit', {
adminUserId: session.user.id,
targetUserId: userId,
organizationId,
})
return NextResponse.json(
{
error:
'Permission denied. Only organization owners and admins can update member usage limits',
},
{ status: 403 }
)
}
logger.info('Authorized member usage limit update', {
adminUserId: session.user.id,
targetUserId: userId,
organizationId,
newLimit: limit,
})
await updateMemberUsageLimit(organizationId, userId, limit, session.user.id)
} else {
return NextResponse.json(
{ error: 'Invalid context. Must be "user" or "member"' },
{ status: 400 }
)
}
// Return updated limit info
const updatedInfo = await getUserUsageLimitInfo(userId)
return NextResponse.json({
success: true,
context,
userId,
organizationId,
data: updatedInfo,
})
} catch (error) {
logger.error('Failed to update usage limit', {
userId: session?.user?.id,
error,
})
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
}
}
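This deletion removes the unified usage-limits endpoint whose context query parameter switched between self-service and admin flows. For the record, a client call presumably looked like this before the removal (IDs are placeholders; URL and payload shapes follow the handler above):

```typescript
// GET your own limits (context defaults to 'user'):
const res = await fetch('/api/usage-limits')

// PUT a member's limit as an organization owner/admin:
await fetch('/api/usage-limits?context=member&userId=usr_123&organizationId=org_456', {
  method: 'PUT',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({ limit: 100 }),
})
```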

Some files were not shown because too many files have changed in this diff.