mirror of
https://github.com/simstudioai/sim.git
synced 2026-01-09 23:17:59 -05:00
Compare commits
259 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60a061e38a | ||
|
|
ab71fcfc49 | ||
|
|
864622c1dc | ||
|
|
8668622d66 | ||
|
|
53dd277cfe | ||
|
|
0e8e8c7a47 | ||
|
|
47da5eb6e8 | ||
|
|
37dcde2afc | ||
|
|
e31627c7c2 | ||
|
|
57c98d86ba | ||
|
|
0f7dfe084a | ||
|
|
afc1632830 | ||
|
|
56eee2c2d2 | ||
|
|
fc558a8eef | ||
|
|
c68cadfb84 | ||
|
|
95d93a2532 | ||
|
|
59b2023124 | ||
|
|
a672f17136 | ||
|
|
1de59668e4 | ||
|
|
26243b99e8 | ||
|
|
fce1423d05 | ||
|
|
3656d3d7ad | ||
|
|
581929bc01 | ||
|
|
11d8188415 | ||
|
|
36c98d18e9 | ||
|
|
0cf87e650d | ||
|
|
baef8d77f9 | ||
|
|
b74ab46820 | ||
|
|
533b4c53e0 | ||
|
|
c2d668c3eb | ||
|
|
1a5d5ddffa | ||
|
|
9de0d91f9a | ||
|
|
3db73ff721 | ||
|
|
9ffb48ee02 | ||
|
|
1f2a317ac2 | ||
|
|
a618d289d8 | ||
|
|
461d7b2342 | ||
|
|
4273161c0f | ||
|
|
54d42b33eb | ||
|
|
2c2c32c64b | ||
|
|
65e861822c | ||
|
|
12135d2aa8 | ||
|
|
f75c807580 | ||
|
|
9ea7ea79e9 | ||
|
|
5bbb349d8a | ||
|
|
ea09fcecb7 | ||
|
|
9ccb7600f9 | ||
|
|
ee17cf461a | ||
|
|
43cb124d97 | ||
|
|
76889fde26 | ||
|
|
7780d9b32b | ||
|
|
4a703a02cb | ||
|
|
a969d09782 | ||
|
|
0bc778130f | ||
|
|
df3d532495 | ||
|
|
f4f8fc051e | ||
|
|
76fac13f3d | ||
|
|
a3838302e0 | ||
|
|
4310dd6c15 | ||
|
|
813a0fb741 | ||
|
|
7e23e942d7 | ||
|
|
7fcbafab97 | ||
|
|
056dc2879c | ||
|
|
1aec32b7e2 | ||
|
|
316c9704af | ||
|
|
4e3a3bd1b1 | ||
|
|
36773e8cdb | ||
|
|
7ac89e35a1 | ||
|
|
faa094195a | ||
|
|
69319d21cd | ||
|
|
8362fd7a83 | ||
|
|
39ad793a9a | ||
|
|
921c755711 | ||
|
|
41ec75fcad | ||
|
|
f2502f5e48 | ||
|
|
f3c4f7e20a | ||
|
|
f578f43c9a | ||
|
|
5c73038023 | ||
|
|
92132024ca | ||
|
|
ed11456de3 | ||
|
|
8739a3d378 | ||
|
|
ca015deea9 | ||
|
|
fd6d927228 | ||
|
|
6ac59a3264 | ||
|
|
aa84c75360 | ||
|
|
ebb8cf8bf9 | ||
|
|
cadfcdbfbd | ||
|
|
7d62c200fa | ||
|
|
df646256b3 | ||
|
|
7c73f5ffe0 | ||
|
|
bb5f40a027 | ||
|
|
5ae5429296 | ||
|
|
fcf128f6db | ||
|
|
56543dafb4 | ||
|
|
7cc4574913 | ||
|
|
3f900947ce | ||
|
|
bda8ee772a | ||
|
|
104d34cc9e | ||
|
|
06e9a6b302 | ||
|
|
fed4e507cc | ||
|
|
389456e0f3 | ||
|
|
c720f23d9b | ||
|
|
89f7d2b943 | ||
|
|
923c05239c | ||
|
|
3424a338b7 | ||
|
|
51b1e97fa2 | ||
|
|
ab74b13802 | ||
|
|
861ab1446a | ||
|
|
e6f519a5a6 | ||
|
|
8226e7b40a | ||
|
|
b177b291cf | ||
|
|
9c3b43325b | ||
|
|
973a5c6497 | ||
|
|
78437c688e | ||
|
|
3b74250335 | ||
|
|
c68800c772 | ||
|
|
5403665fa9 | ||
|
|
3d3443f68e | ||
|
|
e5c0b14367 | ||
|
|
a495516901 | ||
|
|
1f9b4a8ef0 | ||
|
|
3372829c30 | ||
|
|
45372aece5 | ||
|
|
ed9b9ad83f | ||
|
|
766279bb8b | ||
|
|
1038e148c3 | ||
|
|
8b78200991 | ||
|
|
c8f4791582 | ||
|
|
6c9e0ec88b | ||
|
|
bbbf1c2941 | ||
|
|
efc487a845 | ||
|
|
5786909c5e | ||
|
|
833c5fefd5 | ||
|
|
79dd1ccb9f | ||
|
|
730164abee | ||
|
|
25b2c45ec0 | ||
|
|
780870c48e | ||
|
|
fdfa935a09 | ||
|
|
917552f041 | ||
|
|
4846f6c60d | ||
|
|
be810013c7 | ||
|
|
1ee4263e60 | ||
|
|
60c4668682 | ||
|
|
a268fb7c04 | ||
|
|
6c606750f5 | ||
|
|
e13adab14f | ||
|
|
44bc12b474 | ||
|
|
991f0442e9 | ||
|
|
2ebfb576ae | ||
|
|
11a7be54f2 | ||
|
|
f5219d03c3 | ||
|
|
f0643e01b4 | ||
|
|
77b0c5b9ed | ||
|
|
9dbd44e555 | ||
|
|
9ea9f2d52e | ||
|
|
4cd707fadb | ||
|
|
f0b07428bc | ||
|
|
8c9e182e10 | ||
|
|
33dd59f7a7 | ||
|
|
53ee9f99db | ||
|
|
0f2a125eae | ||
|
|
e107363ea7 | ||
|
|
7e364a7977 | ||
|
|
35a37d8b45 | ||
|
|
2b52d88cee | ||
|
|
abad3620a3 | ||
|
|
a37c6bc812 | ||
|
|
cd1bd95952 | ||
|
|
4c9fdbe7fb | ||
|
|
2c47cf4161 | ||
|
|
db1cf8a6db | ||
|
|
c6912095f7 | ||
|
|
154d9eef6a | ||
|
|
c2ded1f3e1 | ||
|
|
ff43528d35 | ||
|
|
692ba69864 | ||
|
|
cb7ce8659b | ||
|
|
5caef3a37d | ||
|
|
a6888da124 | ||
|
|
07b0597f4f | ||
|
|
71e2994f9d | ||
|
|
9973b2c165 | ||
|
|
d9e5777538 | ||
|
|
dd74267313 | ||
|
|
1db72dc823 | ||
|
|
da707fa491 | ||
|
|
9ffaf305bd | ||
|
|
26e6286fda | ||
|
|
c795fc83aa | ||
|
|
cea42f5135 | ||
|
|
6fd6f921dc | ||
|
|
7530fb9a4e | ||
|
|
9a5b035822 | ||
|
|
0c0b6bf967 | ||
|
|
5d74db53ff | ||
|
|
b39bdfd55e | ||
|
|
6b185be9a4 | ||
|
|
214a0358b6 | ||
|
|
bbb5e53e43 | ||
|
|
79e932fed9 | ||
|
|
9ad36c0e34 | ||
|
|
2771c688ff | ||
|
|
d58ceb4bce | ||
|
|
69773c3174 | ||
|
|
1619d63f2a | ||
|
|
9aa1fe8037 | ||
|
|
1b7c111c46 | ||
|
|
bdfb56b262 | ||
|
|
4a7de31eee | ||
|
|
adfe56c720 | ||
|
|
72e3efa875 | ||
|
|
b40fa3aa6e | ||
|
|
f924edde3a | ||
|
|
073030bfaa | ||
|
|
871f4e8e18 | ||
|
|
091343a132 | ||
|
|
63c66bfc31 | ||
|
|
445ca78395 | ||
|
|
d75cc1ed84 | ||
|
|
5a8a703ecb | ||
|
|
6f64188b8d | ||
|
|
60a9a25553 | ||
|
|
52fa388f81 | ||
|
|
5c56cbd558 | ||
|
|
dc19525a6f | ||
|
|
3873f44875 | ||
|
|
09b95f41ea | ||
|
|
af60ccd188 | ||
|
|
eb75afd115 | ||
|
|
fdb8256468 | ||
|
|
570c07bf2a | ||
|
|
5c16e7d390 | ||
|
|
bd38062705 | ||
|
|
d7fd4a9618 | ||
|
|
d972bab206 | ||
|
|
f254d70624 | ||
|
|
8748e1d5f9 | ||
|
|
133a32e6d3 | ||
|
|
97b6bcc43d | ||
|
|
42917ce641 | ||
|
|
5f6d219223 | ||
|
|
bab74307f4 | ||
|
|
16aaa37dad | ||
|
|
c6166a9483 | ||
|
|
0258a1b4ce | ||
|
|
4d4aefa346 | ||
|
|
a0cf003abf | ||
|
|
2e027dd77d | ||
|
|
6133db53d0 | ||
|
|
03bb437e09 | ||
|
|
9f02f88bf5 | ||
|
|
7a1711282e | ||
|
|
58613888b0 | ||
|
|
f1fe2f52cc | ||
|
|
7d05999a70 | ||
|
|
bf07240cfa | ||
|
|
0c7a8efc8d | ||
|
|
f081f5a73c | ||
|
|
72c07e8ad2 |
@@ -77,7 +77,7 @@ services:
|
||||
- POSTGRES_PASSWORD=postgres
|
||||
- POSTGRES_DB=simstudio
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
|
||||
22
.github/workflows/build.yml
vendored
22
.github/workflows/build.yml
vendored
@@ -2,8 +2,7 @@ name: Build and Publish Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
tags: ['v*']
|
||||
branches: [main, staging]
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
@@ -56,7 +55,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to the Container registry
|
||||
if: github.event_name != 'pull_request'
|
||||
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
@@ -70,10 +69,7 @@ jobs:
|
||||
images: ${{ matrix.image }}
|
||||
tags: |
|
||||
type=raw,value=latest-${{ matrix.arch }},enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=ref,event=pr,suffix=-${{ matrix.arch }}
|
||||
type=semver,pattern={{version}},suffix=-${{ matrix.arch }}
|
||||
type=semver,pattern={{major}}.{{minor}},suffix=-${{ matrix.arch }}
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}},suffix=-${{ matrix.arch }}
|
||||
type=raw,value=staging-${{ github.sha }}-${{ matrix.arch }},enable=${{ github.ref == 'refs/heads/staging' }}
|
||||
type=sha,format=long,suffix=-${{ matrix.arch }}
|
||||
|
||||
- name: Build and push Docker image
|
||||
@@ -82,18 +78,18 @@ jobs:
|
||||
context: .
|
||||
file: ${{ matrix.dockerfile }}
|
||||
platforms: ${{ matrix.platform }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
push: ${{ github.event_name != 'pull_request' && github.ref == 'refs/heads/main' }}
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
cache-from: type=gha,scope=build-v2
|
||||
cache-to: type=gha,mode=max,scope=build-v2
|
||||
cache-from: type=gha,scope=build-v3
|
||||
cache-to: type=gha,mode=max,scope=build-v3
|
||||
provenance: false
|
||||
sbom: false
|
||||
|
||||
create-manifests:
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-and-push
|
||||
if: github.event_name != 'pull_request'
|
||||
if: github.event_name != 'pull_request' && github.ref == 'refs/heads/main'
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
@@ -119,10 +115,6 @@ jobs:
|
||||
images: ${{ matrix.image }}
|
||||
tags: |
|
||||
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||
type=ref,event=pr
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}.{{minor}}.{{patch}}
|
||||
type=sha,format=long
|
||||
|
||||
- name: Create and push manifest
|
||||
|
||||
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
node-version: latest
|
||||
|
||||
- name: Install dependencies
|
||||
run: bun install
|
||||
run: bun install --frozen-lockfile
|
||||
|
||||
- name: Run tests with coverage
|
||||
env:
|
||||
|
||||
44
.github/workflows/trigger-deploy.yml
vendored
Normal file
44
.github/workflows/trigger-deploy.yml
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
name: Trigger.dev Deploy
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- staging
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
name: Trigger.dev Deploy
|
||||
runs-on: ubuntu-latest
|
||||
concurrency:
|
||||
group: trigger-deploy-${{ github.ref }}
|
||||
cancel-in-progress: false
|
||||
env:
|
||||
TRIGGER_ACCESS_TOKEN: ${{ secrets.TRIGGER_ACCESS_TOKEN }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 'lts/*'
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version: latest
|
||||
|
||||
- name: Install dependencies
|
||||
run: bun install
|
||||
|
||||
- name: Deploy to Staging
|
||||
if: github.ref == 'refs/heads/staging'
|
||||
working-directory: ./apps/sim
|
||||
run: npx --yes trigger.dev@4.0.1 deploy -e staging
|
||||
|
||||
- name: Deploy to Production
|
||||
if: github.ref == 'refs/heads/main'
|
||||
working-directory: ./apps/sim
|
||||
run: npx --yes trigger.dev@4.0.1 deploy
|
||||
|
||||
57
README.md
57
README.md
@@ -1,50 +1,46 @@
|
||||
<p align="center">
|
||||
<img src="apps/sim/public/static/sim.png" alt="Sim Logo" width="500"/>
|
||||
<a href="https://sim.ai" target="_blank" rel="noopener noreferrer">
|
||||
<img src="apps/sim/public/logo/reverse/text/large.png" alt="Sim Logo" width="500"/>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="https://www.apache.org/licenses/LICENSE-2.0"><img src="https://img.shields.io/badge/License-Apache%202.0-blue.svg" alt="License: Apache-2.0"></a>
|
||||
<a href="https://discord.gg/Hr4UWYEcTT"><img src="https://img.shields.io/badge/Discord-Join%20Server-7289DA?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://x.com/simdotai"><img src="https://img.shields.io/twitter/follow/simstudioai?style=social" alt="Twitter"></a>
|
||||
<a href="https://github.com/simstudioai/sim/pulls"><img src="https://img.shields.io/badge/PRs-welcome-brightgreen.svg" alt="PRs welcome"></a>
|
||||
<a href="https://docs.sim.ai"><img src="https://img.shields.io/badge/Docs-visit%20documentation-blue.svg" alt="Documentation"></a>
|
||||
</p>
|
||||
<p align="center">Build and deploy AI agent workflows in minutes.</p>
|
||||
|
||||
<p align="center">
|
||||
<strong>Sim</strong> is a lightweight, user-friendly platform for building AI agent workflows.
|
||||
<a href="https://sim.ai" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/sim.ai-6F3DFA" alt="Sim.ai"></a>
|
||||
<a href="https://discord.gg/Hr4UWYEcTT" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Discord-Join%20Server-5865F2?logo=discord&logoColor=white" alt="Discord"></a>
|
||||
<a href="https://x.com/simdotai" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/twitter/follow/simstudioai?style=social" alt="Twitter"></a>
|
||||
<a href="https://docs.sim.ai" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/Docs-6F3DFA.svg" alt="Documentation"></a>
|
||||
</p>
|
||||
|
||||
<p align="center">
|
||||
<img src="apps/sim/public/static/demo.gif" alt="Sim Demo" width="800"/>
|
||||
</p>
|
||||
|
||||
## Getting Started
|
||||
## Quickstart
|
||||
|
||||
1. Use our [cloud-hosted version](https://sim.ai)
|
||||
2. Self-host using one of the methods below
|
||||
### Cloud-hosted: [sim.ai](https://sim.ai)
|
||||
|
||||
## Self-Hosting Options
|
||||
<a href="https://sim.ai" target="_blank" rel="noopener noreferrer"><img src="https://img.shields.io/badge/sim.ai-6F3DFA?logo=data:image/svg%2bxml;base64,PHN2ZyB3aWR0aD0iNjE2IiBoZWlnaHQ9IjYxNiIgdmlld0JveD0iMCAwIDYxNiA2MTYiIGZpbGw9Im5vbmUiIHhtbG5zPSJodHRwOi8vd3d3LnczLm9yZy8yMDAwL3N2ZyI+CjxnIGNsaXAtcGF0aD0idXJsKCNjbGlwMF8xMTU5XzMxMykiPgo8cGF0aCBkPSJNNjE2IDBIMFY2MTZINjE2VjBaIiBmaWxsPSIjNkYzREZBIi8+CjxwYXRoIGQ9Ik04MyAzNjUuNTY3SDExM0MxMTMgMzczLjgwNSAxMTYgMzgwLjM3MyAxMjIgMzg1LjI3MkMxMjggMzg5Ljk0OCAxMzYuMTExIDM5Mi4yODUgMTQ2LjMzMyAzOTIuMjg1QzE1Ny40NDQgMzkyLjI4NSAxNjYgMzkwLjE3MSAxNzIgMzg1LjkzOUMxNzcuOTk5IDM4MS40ODcgMTgxIDM3NS41ODYgMTgxIDM2OC4yMzlDMTgxIDM2Mi44OTUgMTc5LjMzMyAzNTguNDQyIDE3NiAzNTQuODhDMTcyLjg4OSAzNTEuMzE4IDE2Ny4xMTEgMzQ4LjQyMiAxNTguNjY3IDM0Ni4xOTZMMTMwIDMzOS41MTdDMTE1LjU1NSAzMzUuOTU1IDEwNC43NzggMzMwLjQ5OSA5Ny42NjY1IDMyMy4xNTFDOTAuNzc3NSAzMTUuODA0IDg3LjMzMzQgMzA2LjExOSA4Ny4zMzM0IDI5NC4wOTZDODcuMzMzNCAyODQuMDc2IDg5Ljg4OSAyNzUuMzkyIDk0Ljk5OTYgMjY4LjA0NUMxMDAuMzMzIDI2MC42OTcgMTA3LjU1NSAyNTUuMDIgMTE2LjY2NiAyNTEuMDEyQzEyNiAyNDcuMDA0IDEzNi42NjcgMjQ1IDE0OC42NjYgMjQ1QzE2MC42NjcgMjQ1IDE3MSAyNDcuMTE2IDE3OS42NjcgMjUxLjM0NkMxODguNTU1IDI1NS41NzYgMTk1LjQ0NCAyNjEuNDc3IDIwMC4zMzMgMjY5LjA0N0MyMDUuNDQ0IDI3Ni42MTcgMjA4LjExMSAyODUuNjM0IDIwOC4zMzMgMjk2LjA5OUgxNzguMzMzQzE3OC4xMTEgMjg3LjYzOCAxNzUuMzMzIDI4MS4wNyAxNjkuOTk5IDI3Ni4zOTRDMTY0LjY2NiAyNzEuNzE5IDE1Ny4yMjIgMjY5LjM4MSAxNDcuNjY3IDI2OS4zODFDMTM3Ljg4OSAyNjkuMzgxIDEzMC4zMzMgMjcxLjQ5NiAxMjUgMjc1LjcyNkMxMTkuNjY2IDI3OS45NTcgMTE3IDI4NS43NDYgMTE3IDI5My4wOTNDMTE3IDMwNC4wMDMgMTI1IDMxMS40NjIgMTQxIDMxNS40N0wxNjkuNjY3IDMyMi40ODNDMTgzLjQ0NSAzMjUuNiAxOTMuNzc4IDMzMC43MjIgMjAwLjY2NyAzMzcuODQ3QzIwNy41NTUgMzQ0Ljc0OSAyMTEgMzU0LjIxMiAyMTEgMzY2LjIzNUMyMTEgMzc2LjQ3NyAyMDguMjIyIDM4NS40OTQgMjAyLjY2NiAzOTMuMjg3QzE5Ny4xMTEgNDAwLjg1NyAxODkuNDQ0IDQwNi43NTggMTc5LjY2NyA0MTAuOTg5QzE3MC4xMTEgNDE0Ljk5NiAxNTguNzc4IDQxNyAxNDUuNjY3IDQxN0MxMjYuNTU1IDQxNyAxMTEuMzMzIDQxMi4zMjUgOTkuOTk5NyA0MDIuOTczQzg4LjY2NjggMzkzLjYyMSA4MyAzODEuMTUzIDgzIDM2NS41NjdaIiBmaWxsPSJ3aGl0ZSIvPgo8cGF0aCBkPSJNMjMyLjI5MSA0MTNWMjUwLjA4MkMyNDQuNjg0IDI1NC42MTQgMjUwLjE0OCAyNTQuNjE0IDI2My4zNzEgMjUwLjA4MlY0MTNIMjMyLjI5MVpNMjQ3LjUgMjM5LjMxM0MyNDEuOTkgMjM5LjMxMyAyMzcuMTQgMjM3LjMxMyAyMzIuOTUyIDIzMy4zMTZDMjI4Ljk4NCAyMjkuMDk1IDIyNyAyMjQuMjA5IDIyNyAyMTguNjU2QzIyNyAyMTIuODgyIDIyOC45ODQgMjA3Ljk5NSAyMzIuOTUyIDIwMy45OTdDMjM3LjE0IDE5OS45OTkgMjQxLjk5IDE5OCAyNDcuNSAxOThDMjUzLjIzMSAxOTggMjU4LjA4IDE5OS45OTkgMjYyLjA0OSAyMDMuOTk3QzI2Ni4wMTYgMjA3Ljk5NSAyNjggMjEyLjg4MiAyNjggMjE4LjY1NkMyNjggMjI0LjIwOSAyNjYuMDE2IDIyOS4wOTUgMjYyLjA0OSAyMzMuMzE2QzI1OC4wOCAyMzcuMzEzIDI1My4yMzEgMjM5LjMxMyAyNDcuNSAyMzkuMzEzWiIgZmlsbD0id2hpdGUiLz4KPHBhdGggZD0iTTMxOS4zMzMgNDEzSDI4OFYyNDkuNjc2SDMxNlYyNzcuMjMzQzMxOS4zMzMgMjY4LjEwNCAzMjUuNzc4IDI2MC4zNjQgMzM0LjY2NyAyNTQuMzUyQzM0My43NzggMjQ4LjExNyAzNTQuNzc4IDI0NSAzNjcuNjY3IDI0NUMzODIuMTExIDI0NSAzOTQuMTEyIDI0OC44OTcgNDAzLjY2NyAyNTYuNjlDNDEzLjIyMiAyNjQuNDg0IDQxOS40NDQgMjc0LjgzNyA0MjIuMzM0IDI4Ny43NTJINDE2LjY2N0M0MTguODg5IDI3NC44MzcgNDI1IDI2NC40ODQgNDM1IDI1Ni42OUM0NDUgMjQ4Ljg5NyA0NTcuMzM0IDI0NSA0NzIgMjQ1QzQ5MC42NjYgMjQ1IDUwNS4zMzQgMjUwLjQ1NSA1MTYgMjYxLjM2NkM1MjYuNjY3IDI3Mi4yNzYgNTMyIDI4Ny4xOTUgNTMyIDMwNi4xMjFWNDEzSDUwMS4zMzNWMzEzLjgwNEM1MDEuMzMzIDMwMC44ODkgNDk4IDI5MC45ODEgNDkxLjMzMyAyODQuMDc4QzQ4NC44ODkgMjc2Ljk1MiA0NzYuMTExIDI3My4zOSA0NjUgMjczLjM5QzQ1Ny4yMjIgMjczLjM5IDQ1MC4zMzMgMjc1LjE3MSA0NDQuMzM0IDI3OC43MzRDNDM4LjU1NiAyODIuMDc0IDQzNCAyODYuOTcyIDQzMC42NjcgMjkzLjQzQzQyNy4zMzMgMjk5Ljg4NyA0MjUuNjY3IDMwNy40NTcgNDI1LjY2NyAzMTYuMTQxVjQxM0gzOTQuNjY3VjMxMy40NjlDMzk0LjY2NyAzMDAuNTU1IDM5MS40NDUgMjkwLjc1OCAzODUgMjg0LjA3OEMzNzguNTU2IDI3Ny4xNzUgMzY5Ljc3OCAyNzMuNzI0IDM1OC42NjcgMjczLjcyNEMzNTAuODg5IDI3My43MjQgMzQ0IDI3NS41MDUgMzM4IDI3OS4wNjhDMzMyLjIyMiAyODIuNDA4IDMyNy42NjcgMjg3LjMwNyAzMjQuMzMzIDI5My43NjNDMzIxIDI5OS45OTggMzE5LjMzMyAzMDcuNDU3IDMxOS4zMzMgMzE2LjE0MVY0MTNaIiBmaWxsPSJ3aGl0ZSIvPgo8L2c+CjxkZWZzPgo8Y2xpcFBhdGggaWQ9ImNsaXAwXzExNTlfMzEzIj4KPHJlY3Qgd2lkdGg9IjYxNiIgaGVpZ2h0PSI2MTYiIGZpbGw9IndoaXRlIi8+CjwvY2xpcFBhdGg+CjwvZGVmcz4KPC9zdmc+Cg==&logoColor=white" alt="Sim.ai"></a>
|
||||
|
||||
### Option 1: NPM Package (Simplest)
|
||||
|
||||
The easiest way to run Sim locally is using our [NPM package](https://www.npmjs.com/package/simstudio?activeTab=readme):
|
||||
### Self-hosted: NPM Package
|
||||
|
||||
```bash
|
||||
npx simstudio
|
||||
```
|
||||
→ http://localhost:3000
|
||||
|
||||
After running these commands, open [http://localhost:3000/](http://localhost:3000/) in your browser.
|
||||
#### Note
|
||||
Docker must be installed and running on your machine.
|
||||
|
||||
#### Options
|
||||
|
||||
- `-p, --port <port>`: Specify the port to run Sim on (default: 3000)
|
||||
- `--no-pull`: Skip pulling the latest Docker images
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `-p, --port <port>` | Port to run Sim on (default `3000`) |
|
||||
| `--no-pull` | Skip pulling latest Docker images |
|
||||
|
||||
#### Requirements
|
||||
|
||||
- Docker must be installed and running on your machine
|
||||
|
||||
### Option 2: Docker Compose
|
||||
### Self-hosted: Docker Compose
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
@@ -76,14 +72,14 @@ Wait for the model to download, then visit [http://localhost:3000](http://localh
|
||||
docker compose -f docker-compose.ollama.yml exec ollama ollama pull llama3.1:8b
|
||||
```
|
||||
|
||||
### Option 3: Dev Containers
|
||||
### Self-hosted: Dev Containers
|
||||
|
||||
1. Open VS Code with the [Remote - Containers extension](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers)
|
||||
2. Open the project and click "Reopen in Container" when prompted
|
||||
3. Run `bun run dev:full` in the terminal or use the `sim-start` alias
|
||||
- This starts both the main application and the realtime socket server
|
||||
|
||||
### Option 4: Manual Setup
|
||||
### Self-hosted: Manual Setup
|
||||
|
||||
**Requirements:**
|
||||
- [Bun](https://bun.sh/) runtime
|
||||
@@ -158,6 +154,13 @@ cd apps/sim
|
||||
bun run dev:sockets
|
||||
```
|
||||
|
||||
## Copilot API Keys
|
||||
|
||||
Copilot is a Sim-managed service. To use Copilot on a self-hosted instance:
|
||||
|
||||
- Go to https://sim.ai → Settings → Copilot and generate a Copilot API key
|
||||
- Set `COPILOT_API_KEY` in your self-hosted environment to that value
|
||||
|
||||
## Tech Stack
|
||||
|
||||
- **Framework**: [Next.js](https://nextjs.org/) (App Router)
|
||||
@@ -180,4 +183,4 @@ We welcome contributions! Please see our [Contributing Guide](.github/CONTRIBUTI
|
||||
|
||||
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
<p align="center">Made with ❤️ by the Sim Team</p>
|
||||
<p align="center">Made with ❤️ by the Sim Team</p>
|
||||
|
||||
121
apps/docs/content/docs/copilot/index.mdx
Normal file
121
apps/docs/content/docs/copilot/index.mdx
Normal file
@@ -0,0 +1,121 @@
|
||||
---
|
||||
title: Copilot
|
||||
description: Build and edit workflows with Sim Copilot
|
||||
---
|
||||
|
||||
import { Callout } from 'fumadocs-ui/components/callout'
|
||||
import { Card, Cards } from 'fumadocs-ui/components/card'
|
||||
import { MessageCircle, Package, Zap, Infinity as InfinityIcon, Brain, BrainCircuit } from 'lucide-react'
|
||||
|
||||
Copilot is your in-editor assistant that helps you build, understand, and improve workflows. It can:
|
||||
|
||||
- **Explain**: Answer questions about Sim and your current workflow
|
||||
- **Guide**: Suggest edits and best practices
|
||||
- **Edit**: Make changes to blocks, connections, and settings when you approve
|
||||
|
||||
<Callout type="info">
|
||||
Copilot is a Sim-managed service. For self-hosted deployments, generate a Copilot API key in the hosted app (sim.ai → Settings → Copilot)
|
||||
1. Go to [sim.ai](https://sim.ai) → Settings → Copilot and generate a Copilot API key
|
||||
2. Set `COPILOT_API_KEY` in your self-hosted environment to that value
|
||||
</Callout>
|
||||
|
||||
## Modes
|
||||
|
||||
<Cards>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<MessageCircle className="h-4 w-4 text-muted-foreground" />
|
||||
Ask
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">
|
||||
Q&A mode for explanations, guidance, and suggestions without making changes to your workflow.
|
||||
</div>
|
||||
</Card>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<Package className="h-4 w-4 text-muted-foreground" />
|
||||
Agent
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">
|
||||
Build-and-edit mode. Copilot proposes specific edits (add blocks, wire variables, tweak settings) and applies them when you approve.
|
||||
</div>
|
||||
</Card>
|
||||
</Cards>
|
||||
|
||||
## Depth Levels
|
||||
|
||||
<Cards>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<Zap className="h-4 w-4 text-muted-foreground" />
|
||||
Fast
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">Quickest and cheapest. Best for small edits, simple workflows, and minor tweaks.</div>
|
||||
</Card>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<InfinityIcon className="h-4 w-4 text-muted-foreground" />
|
||||
Auto
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">Balanced speed and reasoning. Recommended default for most tasks.</div>
|
||||
</Card>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<Brain className="h-4 w-4 text-muted-foreground" />
|
||||
Advanced
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">More reasoning for larger workflows and complex edits while staying performant.</div>
|
||||
</Card>
|
||||
<Card
|
||||
title={
|
||||
<span className="inline-flex items-center gap-2">
|
||||
<BrainCircuit className="h-4 w-4 text-muted-foreground" />
|
||||
Behemoth
|
||||
</span>
|
||||
}
|
||||
>
|
||||
<div className="m-0 text-sm">Maximum reasoning for deep planning, debugging, and complex architectural changes.</div>
|
||||
</Card>
|
||||
</Cards>
|
||||
|
||||
## Billing and Cost Calculation
|
||||
|
||||
### How Costs Are Calculated
|
||||
|
||||
Copilot usage is billed per token from the underlying LLM:
|
||||
|
||||
- **Input tokens**: billed at the provider's base rate (**at-cost**)
|
||||
- **Output tokens**: billed at **1.5×** the provider's base output rate
|
||||
|
||||
```javascript
|
||||
copilotCost = (inputTokens × inputPrice + outputTokens × (outputPrice × 1.5)) / 1,000,000
|
||||
```
|
||||
|
||||
| Component | Rate Applied |
|
||||
|----------|----------------------|
|
||||
| Input | inputPrice |
|
||||
| Output | outputPrice × 1.5 |
|
||||
|
||||
<Callout type="warning">
|
||||
Pricing shown reflects rates as of September 4, 2025. Check provider documentation for current pricing.
|
||||
</Callout>
|
||||
|
||||
<Callout type="info">
|
||||
Model prices are per million tokens. The calculation divides by 1,000,000 to get the actual cost. See <a href="/execution/advanced#cost-calculation">Logging and Cost Calculation</a> for background and examples.
|
||||
</Callout>
|
||||
|
||||
4
apps/docs/content/docs/copilot/meta.json
Normal file
4
apps/docs/content/docs/copilot/meta.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"title": "Copilot",
|
||||
"pages": ["index"]
|
||||
}
|
||||
@@ -12,6 +12,8 @@
|
||||
"connections",
|
||||
"---Execution---",
|
||||
"execution",
|
||||
"---Copilot---",
|
||||
"copilot",
|
||||
"---Advanced---",
|
||||
"./variables/index",
|
||||
"yaml",
|
||||
|
||||
@@ -33,12 +33,16 @@
|
||||
"microsoft_planner",
|
||||
"microsoft_teams",
|
||||
"mistral_parse",
|
||||
"mongodb",
|
||||
"mysql",
|
||||
"notion",
|
||||
"onedrive",
|
||||
"openai",
|
||||
"outlook",
|
||||
"parallel_ai",
|
||||
"perplexity",
|
||||
"pinecone",
|
||||
"postgresql",
|
||||
"qdrant",
|
||||
"reddit",
|
||||
"s3",
|
||||
|
||||
@@ -109,14 +109,13 @@ Read data from a Microsoft Excel spreadsheet
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `spreadsheetId` | string | Yes | The ID of the spreadsheet to read from |
|
||||
| `range` | string | No | The range of cells to read from |
|
||||
| `range` | string | No | The range of cells to read from. Accepts "SheetName!A1:B2" for explicit ranges or just "SheetName" to read the used range of that sheet. If omitted, reads the used range of the first sheet. |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `success` | boolean | Operation success status |
|
||||
| `output` | object | Excel spreadsheet data and metadata |
|
||||
| `data` | object | Range data from the spreadsheet |
|
||||
|
||||
### `microsoft_excel_write`
|
||||
|
||||
@@ -136,8 +135,11 @@ Write data to a Microsoft Excel spreadsheet
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `success` | boolean | Operation success status |
|
||||
| `output` | object | Write operation results and metadata |
|
||||
| `updatedRange` | string | The range that was updated |
|
||||
| `updatedRows` | number | Number of rows that were updated |
|
||||
| `updatedColumns` | number | Number of columns that were updated |
|
||||
| `updatedCells` | number | Number of cells that were updated |
|
||||
| `metadata` | object | Spreadsheet metadata |
|
||||
|
||||
### `microsoft_excel_table_add`
|
||||
|
||||
@@ -155,8 +157,9 @@ Add new rows to a Microsoft Excel table
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `success` | boolean | Operation success status |
|
||||
| `output` | object | Table add operation results and metadata |
|
||||
| `index` | number | Index of the first row that was added |
|
||||
| `values` | array | Array of rows that were added to the table |
|
||||
| `metadata` | object | Spreadsheet metadata |
|
||||
|
||||
|
||||
|
||||
|
||||
264
apps/docs/content/docs/tools/mongodb.mdx
Normal file
264
apps/docs/content/docs/tools/mongodb.mdx
Normal file
@@ -0,0 +1,264 @@
|
||||
---
|
||||
title: MongoDB
|
||||
description: Connect to MongoDB database
|
||||
---
|
||||
|
||||
import { BlockInfoCard } from "@/components/ui/block-info-card"
|
||||
|
||||
<BlockInfoCard
|
||||
type="mongodb"
|
||||
color="#E0E0E0"
|
||||
icon={true}
|
||||
iconSvg={`<svg className="block-icon" xmlns='http://www.w3.org/2000/svg' viewBox='0 0 128 128'>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='currentColor'
|
||||
d='M88.038 42.812c1.605 4.643 2.761 9.383 3.141 14.296.472 6.095.256 12.147-1.029 18.142-.035.165-.109.32-.164.48-.403.001-.814-.049-1.208.012-3.329.523-6.655 1.065-9.981 1.604-3.438.557-6.881 1.092-10.313 1.687-1.216.21-2.721-.041-3.212 1.641-.014.046-.154.054-.235.08l.166-10.051-.169-24.252 1.602-.275c2.62-.429 5.24-.864 7.862-1.281 3.129-.497 6.261-.98 9.392-1.465 1.381-.215 2.764-.412 4.148-.618z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#45A538'
|
||||
d='M61.729 110.054c-1.69-1.453-3.439-2.842-5.059-4.37-8.717-8.222-15.093-17.899-18.233-29.566-.865-3.211-1.442-6.474-1.627-9.792-.13-2.322-.318-4.665-.154-6.975.437-6.144 1.325-12.229 3.127-18.147l.099-.138c.175.233.427.439.516.702 1.759 5.18 3.505 10.364 5.242 15.551 5.458 16.3 10.909 32.604 16.376 48.9.107.318.384.579.583.866l-.87 2.969z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#46A037'
|
||||
d='M88.038 42.812c-1.384.206-2.768.403-4.149.616-3.131.485-6.263.968-9.392 1.465-2.622.417-5.242.852-7.862 1.281l-1.602.275-.012-1.045c-.053-.859-.144-1.717-.154-2.576-.069-5.478-.112-10.956-.18-16.434-.042-3.429-.105-6.857-.175-10.285-.043-2.13-.089-4.261-.185-6.388-.052-1.143-.236-2.28-.311-3.423-.042-.657.016-1.319.029-1.979.817 1.583 1.616 3.178 2.456 4.749 1.327 2.484 3.441 4.314 5.344 6.311 7.523 7.892 12.864 17.068 16.193 27.433z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#409433'
|
||||
d='M65.036 80.753c.081-.026.222-.034.235-.08.491-1.682 1.996-1.431 3.212-1.641 3.432-.594 6.875-1.13 10.313-1.687 3.326-.539 6.652-1.081 9.981-1.604.394-.062.805-.011 1.208-.012-.622 2.22-1.112 4.488-1.901 6.647-.896 2.449-1.98 4.839-3.131 7.182a49.142 49.142 0 01-6.353 9.763c-1.919 2.308-4.058 4.441-6.202 6.548-1.185 1.165-2.582 2.114-3.882 3.161l-.337-.23-1.214-1.038-1.256-2.753a41.402 41.402 0 01-1.394-9.838l.023-.561.171-2.426c.057-.828.133-1.655.168-2.485.129-2.982.241-5.964.359-8.946z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#4FAA41'
|
||||
d='M65.036 80.753c-.118 2.982-.23 5.964-.357 8.947-.035.83-.111 1.657-.168 2.485l-.765.289c-1.699-5.002-3.399-9.951-5.062-14.913-2.75-8.209-5.467-16.431-8.213-24.642a4498.887 4498.887 0 00-6.7-19.867c-.105-.31-.407-.552-.617-.826l4.896-9.002c.168.292.39.565.496.879a6167.476 6167.476 0 016.768 20.118c2.916 8.73 5.814 17.467 8.728 26.198.116.349.308.671.491 1.062l.67-.78-.167 10.052z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#4AA73C'
|
||||
d='M43.155 32.227c.21.274.511.516.617.826a4498.887 4498.887 0 016.7 19.867c2.746 8.211 5.463 16.433 8.213 24.642 1.662 4.961 3.362 9.911 5.062 14.913l.765-.289-.171 2.426-.155.559c-.266 2.656-.49 5.318-.814 7.968-.163 1.328-.509 2.632-.772 3.947-.198-.287-.476-.548-.583-.866-5.467-16.297-10.918-32.6-16.376-48.9a3888.972 3888.972 0 00-5.242-15.551c-.089-.263-.34-.469-.516-.702l3.272-8.84z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#57AE47'
|
||||
d='M65.202 70.702l-.67.78c-.183-.391-.375-.714-.491-1.062-2.913-8.731-5.812-17.468-8.728-26.198a6167.476 6167.476 0 00-6.768-20.118c-.105-.314-.327-.588-.496-.879l6.055-7.965c.191.255.463.482.562.769 1.681 4.921 3.347 9.848 5.003 14.778 1.547 4.604 3.071 9.215 4.636 13.813.105.308.47.526.714.786l.012 1.045c.058 8.082.115 16.167.171 24.251z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#60B24F'
|
||||
d='M65.021 45.404c-.244-.26-.609-.478-.714-.786-1.565-4.598-3.089-9.209-4.636-13.813-1.656-4.93-3.322-9.856-5.003-14.778-.099-.287-.371-.514-.562-.769 1.969-1.928 3.877-3.925 5.925-5.764 1.821-1.634 3.285-3.386 3.352-5.968.003-.107.059-.214.145-.514l.519 1.306c-.013.661-.072 1.322-.029 1.979.075 1.143.259 2.28.311 3.423.096 2.127.142 4.258.185 6.388.069 3.428.132 6.856.175 10.285.067 5.478.111 10.956.18 16.434.008.861.098 1.718.152 2.577z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#A9AA88'
|
||||
d='M62.598 107.085c.263-1.315.609-2.62.772-3.947.325-2.649.548-5.312.814-7.968l.066-.01.066.011a41.402 41.402 0 001.394 9.838c-.176.232-.425.439-.518.701-.727 2.05-1.412 4.116-2.143 6.166-.1.28-.378.498-.574.744l-.747-2.566.87-2.969z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#B6B598'
|
||||
d='M62.476 112.621c.196-.246.475-.464.574-.744.731-2.05 1.417-4.115 2.143-6.166.093-.262.341-.469.518-.701l1.255 2.754c-.248.352-.59.669-.728 1.061l-2.404 7.059c-.099.283-.437.483-.663.722l-.695-3.985z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#C2C1A7'
|
||||
d='M63.171 116.605c.227-.238.564-.439.663-.722l2.404-7.059c.137-.391.48-.709.728-1.061l1.215 1.037c-.587.58-.913 1.25-.717 2.097l-.369 1.208c-.168.207-.411.387-.494.624-.839 2.403-1.64 4.819-2.485 7.222-.107.305-.404.544-.614.812-.109-1.387-.22-2.771-.331-4.158z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#CECDB7'
|
||||
d='M63.503 120.763c.209-.269.506-.508.614-.812.845-2.402 1.646-4.818 2.485-7.222.083-.236.325-.417.494-.624l-.509 5.545c-.136.157-.333.294-.398.477-.575 1.614-1.117 3.24-1.694 4.854-.119.333-.347.627-.525.938-.158-.207-.441-.407-.454-.623-.051-.841-.016-1.688-.013-2.533z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#DBDAC7'
|
||||
d='M63.969 123.919c.178-.312.406-.606.525-.938.578-1.613 1.119-3.239 1.694-4.854.065-.183.263-.319.398-.477l.012 3.64-1.218 3.124-1.411-.495z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#EBE9DC'
|
||||
d='M65.38 124.415l1.218-3.124.251 3.696-1.469-.572z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#CECDB7'
|
||||
d='M67.464 110.898c-.196-.847.129-1.518.717-2.097l.337.23-1.054 1.867z'
|
||||
/>
|
||||
<path
|
||||
fillRule='evenodd'
|
||||
clipRule='evenodd'
|
||||
fill='#4FAA41'
|
||||
d='M64.316 95.172l-.066-.011-.066.01.155-.559-.023.56z'
|
||||
/>
|
||||
</svg>`}
|
||||
/>
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
Connect to any MongoDB database to execute queries, manage data, and perform database operations. Supports find, insert, update, delete, and aggregation operations with secure connection handling.
|
||||
|
||||
|
||||
|
||||
## Tools
|
||||
|
||||
### `mongodb_query`
|
||||
|
||||
Execute find operation on MongoDB collection
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MongoDB server hostname or IP address |
|
||||
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | No | MongoDB username |
|
||||
| `password` | string | No | MongoDB password |
|
||||
| `authSource` | string | No | Authentication database |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `collection` | string | Yes | Collection name to query |
|
||||
| `query` | string | No | MongoDB query filter as JSON string |
|
||||
| `limit` | number | No | Maximum number of documents to return |
|
||||
| `sort` | string | No | Sort criteria as JSON string |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `documents` | array | Array of documents returned from the query |
|
||||
| `documentCount` | number | Number of documents returned |
|
||||
|
||||
### `mongodb_insert`
|
||||
|
||||
Insert documents into MongoDB collection
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MongoDB server hostname or IP address |
|
||||
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | No | MongoDB username |
|
||||
| `password` | string | No | MongoDB password |
|
||||
| `authSource` | string | No | Authentication database |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `collection` | string | Yes | Collection name to insert into |
|
||||
| `documents` | array | Yes | Array of documents to insert |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `documentCount` | number | Number of documents inserted |
|
||||
| `insertedId` | string | ID of inserted document \(single insert\) |
|
||||
| `insertedIds` | array | Array of inserted document IDs \(multiple insert\) |
|
||||
|
||||
### `mongodb_update`
|
||||
|
||||
Update documents in MongoDB collection
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MongoDB server hostname or IP address |
|
||||
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | No | MongoDB username |
|
||||
| `password` | string | No | MongoDB password |
|
||||
| `authSource` | string | No | Authentication database |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `collection` | string | Yes | Collection name to update |
|
||||
| `filter` | string | Yes | Filter criteria as JSON string |
|
||||
| `update` | string | Yes | Update operations as JSON string |
|
||||
| `upsert` | boolean | No | Create document if not found |
|
||||
| `multi` | boolean | No | Update multiple documents |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `matchedCount` | number | Number of documents matched by filter |
|
||||
| `modifiedCount` | number | Number of documents modified |
|
||||
| `documentCount` | number | Total number of documents affected |
|
||||
| `insertedId` | string | ID of inserted document \(if upsert\) |
|
||||
|
||||
### `mongodb_delete`
|
||||
|
||||
Delete documents from MongoDB collection
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MongoDB server hostname or IP address |
|
||||
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | No | MongoDB username |
|
||||
| `password` | string | No | MongoDB password |
|
||||
| `authSource` | string | No | Authentication database |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `collection` | string | Yes | Collection name to delete from |
|
||||
| `filter` | string | Yes | Filter criteria as JSON string |
|
||||
| `multi` | boolean | No | Delete multiple documents |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `deletedCount` | number | Number of documents deleted |
|
||||
| `documentCount` | number | Total number of documents affected |
|
||||
|
||||
### `mongodb_execute`
|
||||
|
||||
Execute MongoDB aggregation pipeline
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MongoDB server hostname or IP address |
|
||||
| `port` | number | Yes | MongoDB server port \(default: 27017\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | No | MongoDB username |
|
||||
| `password` | string | No | MongoDB password |
|
||||
| `authSource` | string | No | Authentication database |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `collection` | string | Yes | Collection name to execute pipeline on |
|
||||
| `pipeline` | string | Yes | Aggregation pipeline as JSON string |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `documents` | array | Array of documents returned from aggregation |
|
||||
| `documentCount` | number | Number of documents returned |
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Category: `tools`
|
||||
- Type: `mongodb`
|
||||
180
apps/docs/content/docs/tools/mysql.mdx
Normal file
180
apps/docs/content/docs/tools/mysql.mdx
Normal file
@@ -0,0 +1,180 @@
|
||||
---
|
||||
title: MySQL
|
||||
description: Connect to MySQL database
|
||||
---
|
||||
|
||||
import { BlockInfoCard } from "@/components/ui/block-info-card"
|
||||
|
||||
<BlockInfoCard
|
||||
type="mysql"
|
||||
color="#E0E0E0"
|
||||
icon={true}
|
||||
iconSvg={`<svg className="block-icon"
|
||||
|
||||
xmlns='http://www.w3.org/2000/svg'
|
||||
|
||||
|
||||
viewBox='0 0 25.6 25.6'
|
||||
>
|
||||
<path
|
||||
d='M179.076 94.886c-3.568-.1-6.336.268-8.656 1.25-.668.27-1.74.27-1.828 1.116.357.355.4.936.713 1.428.535.893 1.473 2.096 2.32 2.72l2.855 2.053c1.74 1.07 3.703 1.695 5.398 2.766.982.625 1.963 1.428 2.945 2.098.5.357.803.938 1.428 1.16v-.135c-.312-.4-.402-.98-.713-1.428l-1.34-1.293c-1.293-1.74-2.9-3.258-4.64-4.506-1.428-.982-4.55-2.32-5.13-3.97l-.088-.1c.98-.1 2.14-.447 3.078-.715 1.518-.4 2.9-.312 4.46-.713l2.143-.625v-.4c-.803-.803-1.383-1.874-2.23-2.632-2.275-1.963-4.775-3.882-7.363-5.488-1.383-.892-3.168-1.473-4.64-2.23-.537-.268-1.428-.402-1.74-.848-.805-.98-1.25-2.275-1.83-3.436l-3.658-7.763c-.803-1.74-1.295-3.48-2.275-5.086-4.596-7.585-9.594-12.18-17.268-16.687-1.65-.937-3.613-1.34-5.7-1.83l-3.346-.18c-.715-.312-1.428-1.16-2.053-1.562-2.543-1.606-9.102-5.086-10.977-.5-1.205 2.9 1.785 5.755 2.8 7.228.76 1.026 1.74 2.186 2.277 3.346.3.758.4 1.562.713 2.365.713 1.963 1.383 4.15 2.32 5.98.5.937 1.025 1.92 1.65 2.767.357.5.982.714 1.115 1.517-.625.893-.668 2.23-1.025 3.347-1.607 5.042-.982 11.288 1.293 15 .715 1.115 2.4 3.57 4.686 2.632 2.008-.803 1.56-3.346 2.14-5.577.135-.535.045-.892.312-1.25v.1l1.83 3.703c1.383 2.186 3.793 4.462 5.8 5.98 1.07.803 1.918 2.187 3.256 2.677v-.135h-.088c-.268-.4-.67-.58-1.027-.892-.803-.803-1.695-1.785-2.32-2.677-1.873-2.498-3.523-5.265-4.996-8.12-.715-1.383-1.34-2.9-1.918-4.283-.27-.536-.27-1.34-.715-1.606-.67.98-1.65 1.83-2.143 3.034-.848 1.918-.936 4.283-1.248 6.737-.18.045-.1 0-.18.1-1.426-.356-1.918-1.83-2.453-3.078-1.338-3.168-1.562-8.254-.402-11.913.312-.937 1.652-3.882 1.117-4.774-.27-.848-1.16-1.338-1.652-2.008-.58-.848-1.203-1.918-1.605-2.855-1.07-2.5-1.605-5.265-2.766-7.764-.537-1.16-1.473-2.365-2.232-3.435-.848-1.205-1.783-2.053-2.453-3.48-.223-.5-.535-1.294-.178-1.83.088-.357.268-.5.623-.58.58-.5 2.232.134 2.812.4 1.65.67 3.033 1.294 4.416 2.23.625.446 1.295 1.294 2.098 1.518h.938c1.428.312 3.033.1 4.37.5 2.365.76 4.506 1.874 6.426 3.08 5.844 3.703 10.664 8.968 13.92 15.26.535 1.026.758 1.963 1.25 3.034.938 2.187 2.098 4.417 3.033 6.56.938 2.097 1.83 4.24 3.168 5.98.67.937 3.346 1.427 4.55 1.918.893.4 2.275.76 3.08 1.25 1.516.937 3.033 2.008 4.46 3.034.713.534 2.945 1.65 3.078 2.54zm-45.5-38.772a7.09 7.09 0 0 0-1.828.223v.1h.088c.357.714.982 1.205 1.428 1.83l1.027 2.142.088-.1c.625-.446.938-1.16.938-2.23-.268-.312-.312-.625-.535-.937-.268-.446-.848-.67-1.206-1.026z'
|
||||
transform='matrix(.390229 0 0 .38781 -46.300037 -16.856717)'
|
||||
fillRule='evenodd'
|
||||
fill='#00678c'
|
||||
/>
|
||||
</svg>`}
|
||||
/>
|
||||
|
||||
{/* MANUAL-CONTENT-START:intro */}
|
||||
The [MySQL](https://www.mysql.com/) tool enables you to connect to any MySQL database and perform a wide range of database operations directly within your agentic workflows. With secure connection handling and flexible configuration, you can easily manage and interact with your data.
|
||||
|
||||
With the MySQL tool, you can:
|
||||
|
||||
- **Query data**: Execute SELECT queries to retrieve data from your MySQL tables using the `mysql_query` operation.
|
||||
- **Insert records**: Add new rows to your tables with the `mysql_insert` operation by specifying the table and data to insert.
|
||||
- **Update records**: Modify existing data in your tables using the `mysql_update` operation, providing the table, new data, and WHERE conditions.
|
||||
- **Delete records**: Remove rows from your tables with the `mysql_delete` operation, specifying the table and WHERE conditions.
|
||||
- **Execute raw SQL**: Run any custom SQL command using the `mysql_execute` operation for advanced use cases.
|
||||
|
||||
The MySQL tool is ideal for scenarios where your agents need to interact with structured data—such as automating reporting, syncing data between systems, or powering data-driven workflows. It streamlines database access, making it easy to read, write, and manage your MySQL data programmatically.
|
||||
{/* MANUAL-CONTENT-END */}
|
||||
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
Connect to any MySQL database to execute queries, manage data, and perform database operations. Supports SELECT, INSERT, UPDATE, DELETE operations with secure connection handling.
|
||||
|
||||
|
||||
|
||||
## Tools
|
||||
|
||||
### `mysql_query`
|
||||
|
||||
Execute SELECT query on MySQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MySQL server hostname or IP address |
|
||||
| `port` | number | Yes | MySQL server port \(default: 3306\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `query` | string | Yes | SQL SELECT query to execute |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of rows returned from the query |
|
||||
| `rowCount` | number | Number of rows returned |
|
||||
|
||||
### `mysql_insert`
|
||||
|
||||
Insert new record into MySQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MySQL server hostname or IP address |
|
||||
| `port` | number | Yes | MySQL server port \(default: 3306\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to insert into |
|
||||
| `data` | object | Yes | Data to insert as key-value pairs |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of inserted rows |
|
||||
| `rowCount` | number | Number of rows inserted |
|
||||
|
||||
### `mysql_update`
|
||||
|
||||
Update existing records in MySQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MySQL server hostname or IP address |
|
||||
| `port` | number | Yes | MySQL server port \(default: 3306\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to update |
|
||||
| `data` | object | Yes | Data to update as key-value pairs |
|
||||
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of updated rows |
|
||||
| `rowCount` | number | Number of rows updated |
|
||||
|
||||
### `mysql_delete`
|
||||
|
||||
Delete records from MySQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MySQL server hostname or IP address |
|
||||
| `port` | number | Yes | MySQL server port \(default: 3306\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to delete from |
|
||||
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of deleted rows |
|
||||
| `rowCount` | number | Number of rows deleted |
|
||||
|
||||
### `mysql_execute`
|
||||
|
||||
Execute raw SQL query on MySQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | MySQL server hostname or IP address |
|
||||
| `port` | number | Yes | MySQL server port \(default: 3306\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `query` | string | Yes | Raw SQL query to execute |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of rows returned from the query |
|
||||
| `rowCount` | number | Number of rows affected |
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Category: `tools`
|
||||
- Type: `mysql`
|
||||
@@ -68,7 +68,7 @@ Upload a file to OneDrive
|
||||
| `fileName` | string | Yes | The name of the file to upload |
|
||||
| `content` | string | Yes | The content of the file to upload |
|
||||
| `folderSelector` | string | No | Select the folder to upload the file to |
|
||||
| `folderId` | string | No | The ID of the folder to upload the file to \(internal use\) |
|
||||
| `manualFolderId` | string | No | Manually entered folder ID \(advanced mode\) |
|
||||
|
||||
#### Output
|
||||
|
||||
@@ -87,7 +87,7 @@ Create a new folder in OneDrive
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `folderName` | string | Yes | Name of the folder to create |
|
||||
| `folderSelector` | string | No | Select the parent folder to create the folder in |
|
||||
| `folderId` | string | No | ID of the parent folder \(internal use\) |
|
||||
| `manualFolderId` | string | No | Manually entered parent folder ID \(advanced mode\) |
|
||||
|
||||
#### Output
|
||||
|
||||
@@ -105,7 +105,7 @@ List files and folders in OneDrive
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `folderSelector` | string | No | Select the folder to list files from |
|
||||
| `folderId` | string | No | The ID of the folder to list files from \(internal use\) |
|
||||
| `manualFolderId` | string | No | The manually entered folder ID \(advanced mode\) |
|
||||
| `query` | string | No | A query to filter the files |
|
||||
| `pageSize` | number | No | The number of files to return |
|
||||
|
||||
|
||||
@@ -211,10 +211,27 @@ Read emails from Outlook
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `success` | boolean | Email read operation success status |
|
||||
| `messageCount` | number | Number of emails retrieved |
|
||||
| `messages` | array | Array of email message objects |
|
||||
| `message` | string | Success or status message |
|
||||
| `results` | array | Array of email message objects |
|
||||
|
||||
### `outlook_forward`
|
||||
|
||||
Forward an existing Outlook message to specified recipients
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `messageId` | string | Yes | The ID of the message to forward |
|
||||
| `to` | string | Yes | Recipient email address\(es\), comma-separated |
|
||||
| `comment` | string | No | Optional comment to include with the forwarded message |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Success or error message |
|
||||
| `results` | object | Delivery result details |
|
||||
|
||||
|
||||
|
||||
|
||||
106
apps/docs/content/docs/tools/parallel_ai.mdx
Normal file
106
apps/docs/content/docs/tools/parallel_ai.mdx
Normal file
@@ -0,0 +1,106 @@
|
||||
---
|
||||
title: Parallel AI
|
||||
description: Search with Parallel AI
|
||||
---
|
||||
|
||||
import { BlockInfoCard } from "@/components/ui/block-info-card"
|
||||
|
||||
<BlockInfoCard
|
||||
type="parallel_ai"
|
||||
color="#E0E0E0"
|
||||
icon={true}
|
||||
iconSvg={`<svg className="block-icon"
|
||||
|
||||
fill='currentColor'
|
||||
|
||||
|
||||
viewBox='0 0 271 270'
|
||||
xmlns='http://www.w3.org/2000/svg'
|
||||
>
|
||||
<path
|
||||
d='M267.804 105.65H193.828C194.026 106.814 194.187 107.996 194.349 109.178H76.6703C76.4546 110.736 76.2388 112.312 76.0591 113.87H1.63342C1.27387 116.198 0.950289 118.543 0.698608 120.925H75.3759C75.2501 122.483 75.1602 124.059 75.0703 125.617H195.949C196.003 126.781 196.057 127.962 196.093 129.144H270.68V125.384C270.195 118.651 269.242 112.061 267.804 105.65Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M195.949 144.401H75.0703C75.1422 145.977 75.2501 147.535 75.3759 149.093H0.698608C0.950289 151.457 1.2559 153.802 1.63342 156.148H76.0591C76.2388 157.724 76.4366 159.282 76.6703 160.84H194.349C194.187 162.022 194.008 163.186 193.828 164.367H267.804C269.242 157.957 270.195 151.367 270.68 144.634V140.874H196.093C196.057 142.055 196.003 143.219 195.949 144.401Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M190.628 179.642H80.3559C80.7514 181.218 81.1828 182.776 81.6143 184.334H9.30994C10.2448 186.715 11.2515 189.061 12.3121 191.389H83.7536C84.2749 192.965 84.7962 194.523 85.3535 196.08H185.594C185.163 197.262 184.732 198.426 184.282 199.608H254.519C258.6 192.177 261.98 184.316 264.604 176.114H191.455C191.185 177.296 190.898 178.46 190.61 179.642H190.628Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M177.666 214.883H93.3352C94.1082 216.458 94.9172 218.034 95.7441 219.574H29.8756C31.8351 221.992 33.8666 224.337 35.9699 226.63H99.6632C100.598 228.205 101.551 229.781 102.522 231.321H168.498C167.761 232.503 167.006 233.685 166.233 234.849H226.762C234.474 227.847 241.36 219.95 247.292 211.355H179.356C178.799 212.537 178.26 213.719 177.684 214.883H177.666Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M154.943 250.106H116.058C117.371 251.699 118.701 253.257 120.067 254.797H73.021C91.6094 264.431 112.715 269.946 135.096 270C135.24 270 135.366 270 135.492 270C135.618 270 135.761 270 135.887 270C164.04 269.911 190.178 261.28 211.805 246.56H157.748C156.813 247.742 155.878 248.924 154.925 250.088L154.943 250.106Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M116.059 19.9124H154.943C155.896 21.0764 156.831 22.2582 157.766 23.4401H211.823C190.179 8.72065 164.058 0.0895344 135.906 0C135.762 0 135.636 0 135.51 0C135.384 0 135.24 0 135.115 0C112.715 0.0716275 91.6277 5.56904 73.0393 15.2029H120.086C118.719 16.7429 117.389 18.3187 116.077 19.8945L116.059 19.9124Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M93.3356 55.1532H177.667C178.242 56.3171 178.799 57.499 179.339 58.6808H247.274C241.342 50.0855 234.457 42.1886 226.744 35.187H166.215C166.988 36.351 167.743 37.5328 168.48 38.7147H102.504C101.533 40.2726 100.58 41.8305 99.6456 43.4063H35.9523C33.831 45.6804 31.7996 48.0262 29.858 50.4616H95.7265C94.8996 52.0195 94.1086 53.5774 93.3176 55.1532H93.3356Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
<path
|
||||
d='M80.3736 90.3758H190.646C190.933 91.5398 191.221 92.7216 191.491 93.9035H264.64C262.015 85.7021 258.636 77.841 254.555 70.4097H184.318C184.767 71.5736 185.199 72.7555 185.63 73.9373H85.3893C84.832 75.4952 84.2927 77.0531 83.7893 78.6289H12.3479C11.2872 80.9389 10.2805 83.2847 9.3457 85.6842H81.65C81.2186 87.2421 80.7871 88.8 80.3916 90.3758H80.3736Z'
|
||||
fill='#1D1C1A'
|
||||
/>
|
||||
</svg>`}
|
||||
/>
|
||||
|
||||
{/* MANUAL-CONTENT-START:intro */}
|
||||
[Parallel AI](https://parallel.ai/) is an advanced web search and content extraction platform designed to deliver comprehensive, high-quality results for any query. By leveraging intelligent processing and large-scale data extraction, Parallel AI enables users and agents to access, analyze, and synthesize information from across the web with speed and accuracy.
|
||||
|
||||
With Parallel AI, you can:
|
||||
|
||||
- **Search the web intelligently**: Retrieve relevant, up-to-date information from a wide range of sources
|
||||
- **Extract and summarize content**: Get concise, meaningful excerpts from web pages and documents
|
||||
- **Customize search objectives**: Tailor queries to specific needs or questions for targeted results
|
||||
- **Process results at scale**: Handle large volumes of search results with advanced processing options
|
||||
- **Integrate with workflows**: Use Parallel AI within Sim to automate research, content gathering, and knowledge extraction
|
||||
- **Control output granularity**: Specify the number of results and the amount of content per result
|
||||
- **Secure API access**: Protect your searches and data with API key authentication
|
||||
|
||||
In Sim, the Parallel AI integration empowers your agents to perform web searches and extract content programmatically. This enables powerful automation scenarios such as real-time research, competitive analysis, content monitoring, and knowledge base creation. By connecting Sim with Parallel AI, you unlock the ability for agents to gather, process, and utilize web data as part of your automated workflows.
|
||||
{/* MANUAL-CONTENT-END */}
|
||||
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
Search the web using Parallel AI's advanced search capabilities. Get comprehensive results with intelligent processing and content extraction.
|
||||
|
||||
|
||||
|
||||
## Tools
|
||||
|
||||
### `parallel_search`
|
||||
|
||||
Search the web using Parallel AI. Provides comprehensive search results with intelligent processing and content extraction.
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `objective` | string | Yes | The search objective or question to answer |
|
||||
| `search_queries` | string | No | Optional comma-separated list of search queries to execute |
|
||||
| `processor` | string | No | Processing method: base or pro \(default: base\) |
|
||||
| `max_results` | number | No | Maximum number of results to return \(default: 5\) |
|
||||
| `max_chars_per_result` | number | No | Maximum characters per result \(default: 1500\) |
|
||||
| `apiKey` | string | Yes | Parallel AI API Key |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `results` | array | Search results with excerpts from relevant pages |
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Category: `tools`
|
||||
- Type: `parallel_ai`
|
||||
188
apps/docs/content/docs/tools/postgresql.mdx
Normal file
188
apps/docs/content/docs/tools/postgresql.mdx
Normal file
@@ -0,0 +1,188 @@
|
||||
---
|
||||
title: PostgreSQL
|
||||
description: Connect to PostgreSQL database
|
||||
---
|
||||
|
||||
import { BlockInfoCard } from "@/components/ui/block-info-card"
|
||||
|
||||
<BlockInfoCard
|
||||
type="postgresql"
|
||||
color="#336791"
|
||||
icon={true}
|
||||
iconSvg={`<svg className="block-icon"
|
||||
|
||||
|
||||
|
||||
viewBox='-4 0 264 264'
|
||||
xmlns='http://www.w3.org/2000/svg'
|
||||
preserveAspectRatio='xMinYMin meet'
|
||||
>
|
||||
<path d='M255.008 158.086c-1.535-4.649-5.556-7.887-10.756-8.664-2.452-.366-5.26-.21-8.583.475-5.792 1.195-10.089 1.65-13.225 1.738 11.837-19.985 21.462-42.775 27.003-64.228 8.96-34.689 4.172-50.492-1.423-57.64C233.217 10.847 211.614.683 185.552.372c-13.903-.17-26.108 2.575-32.475 4.549-5.928-1.046-12.302-1.63-18.99-1.738-12.537-.2-23.614 2.533-33.079 8.15-5.24-1.772-13.65-4.27-23.362-5.864-22.842-3.75-41.252-.828-54.718 8.685C6.622 25.672-.937 45.684.461 73.634c.444 8.874 5.408 35.874 13.224 61.48 4.492 14.718 9.282 26.94 14.237 36.33 7.027 13.315 14.546 21.156 22.987 23.972 4.731 1.576 13.327 2.68 22.368-4.85 1.146 1.388 2.675 2.767 4.704 4.048 2.577 1.625 5.728 2.953 8.875 3.74 11.341 2.835 21.964 2.126 31.027-1.848.056 1.612.099 3.152.135 4.482.06 2.157.12 4.272.199 6.25.537 13.374 1.447 23.773 4.143 31.049.148.4.347 1.01.557 1.657 1.345 4.118 3.594 11.012 9.316 16.411 5.925 5.593 13.092 7.308 19.656 7.308 3.292 0 6.433-.432 9.188-1.022 9.82-2.105 20.973-5.311 29.041-16.799 7.628-10.86 11.336-27.217 12.007-52.99.087-.729.167-1.425.244-2.088l.16-1.362 1.797.158.463.031c10.002.456 22.232-1.665 29.743-5.154 5.935-2.754 24.954-12.795 20.476-26.351' />
|
||||
<path
|
||||
d='M237.906 160.722c-29.74 6.135-31.785-3.934-31.785-3.934 31.4-46.593 44.527-105.736 33.2-120.211-30.904-39.485-84.399-20.811-85.292-20.327l-.287.052c-5.876-1.22-12.451-1.946-19.842-2.067-13.456-.22-23.664 3.528-31.41 9.402 0 0-95.43-39.314-90.991 49.444.944 18.882 27.064 142.873 58.218 105.422 11.387-13.695 22.39-25.274 22.39-25.274 5.464 3.63 12.006 5.482 18.864 4.817l.533-.452c-.166 1.7-.09 3.363.213 5.332-8.026 8.967-5.667 10.541-21.711 13.844-16.235 3.346-6.698 9.302-.471 10.86 7.549 1.887 25.013 4.561 36.813-11.958l-.47 1.885c3.144 2.519 5.352 16.383 4.982 28.952-.37 12.568-.617 21.197 1.86 27.937 2.479 6.74 4.948 21.905 26.04 17.386 17.623-3.777 26.756-13.564 28.027-29.89.901-11.606 2.942-9.89 3.07-20.267l1.637-4.912c1.887-15.733.3-20.809 11.157-18.448l2.64.232c7.99.363 18.45-1.286 24.589-4.139 13.218-6.134 21.058-16.377 8.024-13.686h.002'
|
||||
fill='#336791'
|
||||
/>
|
||||
<path
|
||||
d='M108.076 81.525c-2.68-.373-5.107-.028-6.335.902-.69.523-.904 1.129-.962 1.546-.154 1.105.62 2.327 1.096 2.957 1.346 1.784 3.312 3.01 5.258 3.28.282.04.563.058.842.058 3.245 0 6.196-2.527 6.456-4.392.325-2.336-3.066-3.893-6.355-4.35M196.86 81.599c-.256-1.831-3.514-2.353-6.606-1.923-3.088.43-6.082 1.824-5.832 3.659.2 1.427 2.777 3.863 5.827 3.863.258 0 .518-.017.78-.054 2.036-.282 3.53-1.575 4.24-2.32 1.08-1.136 1.706-2.402 1.591-3.225'
|
||||
fill='#FFF'
|
||||
/>
|
||||
<path
|
||||
d='M247.802 160.025c-1.134-3.429-4.784-4.532-10.848-3.28-18.005 3.716-24.453 1.142-26.57-.417 13.995-21.32 25.508-47.092 31.719-71.137 2.942-11.39 4.567-21.968 4.7-30.59.147-9.463-1.465-16.417-4.789-20.665-13.402-17.125-33.072-26.311-56.882-26.563-16.369-.184-30.199 4.005-32.88 5.183-5.646-1.404-11.801-2.266-18.502-2.376-12.288-.199-22.91 2.743-31.704 8.74-3.82-1.422-13.692-4.811-25.765-6.756-20.872-3.36-37.458-.814-49.294 7.571-14.123 10.006-20.643 27.892-19.38 53.16.425 8.501 5.269 34.653 12.913 59.698 10.062 32.964 21 51.625 32.508 55.464 1.347.449 2.9.763 4.613.763 4.198 0 9.345-1.892 14.7-8.33a529.832 529.832 0 0 1 20.261-22.926c4.524 2.428 9.494 3.784 14.577 3.92.01.133.023.266.035.398a117.66 117.66 0 0 0-2.57 3.175c-3.522 4.471-4.255 5.402-15.592 7.736-3.225.666-11.79 2.431-11.916 8.435-.136 6.56 10.125 9.315 11.294 9.607 4.074 1.02 7.999 1.523 11.742 1.523 9.103 0 17.114-2.992 23.516-8.781-.197 23.386.778 46.43 3.586 53.451 2.3 5.748 7.918 19.795 25.664 19.794 2.604 0 5.47-.303 8.623-.979 18.521-3.97 26.564-12.156 29.675-30.203 1.665-9.645 4.522-32.676 5.866-45.03 2.836.885 6.487 1.29 10.434 1.289 8.232 0 17.731-1.749 23.688-4.514 6.692-3.108 18.768-10.734 16.578-17.36zm-44.106-83.48c-.061 3.647-.563 6.958-1.095 10.414-.573 3.717-1.165 7.56-1.314 12.225-.147 4.54.42 9.26.968 13.825 1.108 9.22 2.245 18.712-2.156 28.078a36.508 36.508 0 0 1-1.95-4.009c-.547-1.326-1.735-3.456-3.38-6.404-6.399-11.476-21.384-38.35-13.713-49.316 2.285-3.264 8.084-6.62 22.64-4.813zm-17.644-61.787c21.334.471 38.21 8.452 50.158 23.72 9.164 11.711-.927 64.998-30.14 110.969a171.33 171.33 0 0 0-.886-1.117l-.37-.462c7.549-12.467 6.073-24.802 4.759-35.738-.54-4.488-1.05-8.727-.92-12.709.134-4.22.692-7.84 1.232-11.34.663-4.313 1.338-8.776 1.152-14.037.139-.552.195-1.204.122-1.978-.475-5.045-6.235-20.144-17.975-33.81-6.422-7.475-15.787-15.84-28.574-21.482 5.5-1.14 13.021-2.203 21.442-2.016zM66.674 175.778c-5.9 7.094-9.974 5.734-11.314 5.288-8.73-2.912-18.86-21.364-27.791-50.624-7.728-25.318-12.244-50.777-12.602-57.916-1.128-22.578 4.345-38.313 16.268-46.769 19.404-13.76 51.306-5.524 64.125-1.347-.184.182-.376.352-.558.537-21.036 21.244-20.537 57.54-20.485 59.759-.002.856.07 2.068.168 3.735.362 6.105 1.036 17.467-.764 30.334-1.672 11.957 2.014 23.66 10.111 32.109a36.275 36.275 0 0 0 2.617 2.468c-3.604 3.86-11.437 12.396-19.775 22.426zm22.479-29.993c-6.526-6.81-9.49-16.282-8.133-25.99 1.9-13.592 1.199-25.43.822-31.79-.053-.89-.1-1.67-.127-2.285 3.073-2.725 17.314-10.355 27.47-8.028 4.634 1.061 7.458 4.217 8.632 9.645 6.076 28.103.804 39.816-3.432 49.229-.873 1.939-1.698 3.772-2.402 5.668l-.546 1.466c-1.382 3.706-2.668 7.152-3.465 10.424-6.938-.02-13.687-2.984-18.819-8.34zm1.065 37.9c-2.026-.506-3.848-1.385-4.917-2.114.893-.42 2.482-.992 5.238-1.56 13.337-2.745 15.397-4.683 19.895-10.394 1.031-1.31 2.2-2.794 3.819-4.602l.002-.002c2.411-2.7 3.514-2.242 5.514-1.412 1.621.67 3.2 2.702 3.84 4.938.303 1.056.643 3.06-.47 4.62-9.396 13.156-23.088 12.987-32.921 10.526zm69.799 64.952c-16.316 3.496-22.093-4.829-25.9-14.346-2.457-6.144-3.665-33.85-2.808-64.447.011-.407-.047-.8-.159-1.17a15.444 15.444 0 0 0-.456-2.162c-1.274-4.452-4.379-8.176-8.104-9.72-1.48-.613-4.196-1.738-7.46-.903.696-2.868 1.903-6.107 3.212-9.614l.549-1.475c.618-1.663 1.394-3.386 2.214-5.21 4.433-9.848 10.504-23.337 3.915-53.81-2.468-11.414-10.71-16.988-23.204-15.693-7.49.775-14.343 3.797-17.761 5.53-.735.372-1.407.732-2.035 1.082.954-11.5 4.558-32.992 18.04-46.59 8.489-8.56 19.794-12.788 33.568-12.56 27.14.444 44.544 14.372 54.366 25.979 8.464 10.001 13.047 20.076 14.876 25.51-13.755-1.399-23.11 1.316-27.852 8.096-10.317 14.748 5.644 43.372 13.315 57.129 1.407 2.521 2.621 4.7 3.003 5.626 2.498 6.054 5.732 10.096 8.093 13.046.724.904 1.426 1.781 1.96 2.547-4.166 1.201-11.649 3.976-10.967 17.847-.55 6.96-4.461 39.546-6.448 51.059-2.623 15.21-8.22 20.875-23.957 24.25zm68.104-77.936c-4.26 1.977-11.389 3.46-18.161 3.779-7.48.35-11.288-.838-12.184-1.569-.42-8.644 2.797-9.547 6.202-10.503.535-.15 1.057-.297 1.561-.473.313.255.656.508 1.032.756 6.012 3.968 16.735 4.396 31.874 1.271l.166-.033c-2.042 1.909-5.536 4.471-10.49 6.772z'
|
||||
fill='#FFF'
|
||||
/>
|
||||
</svg>`}
|
||||
/>
|
||||
|
||||
{/* MANUAL-CONTENT-START:intro */}
|
||||
The [PostgreSQL](https://www.postgresql.org/) tool enables you to connect to any PostgreSQL database and perform a wide range of database operations directly within your agentic workflows. With secure connection handling and flexible configuration, you can easily manage and interact with your data.
|
||||
|
||||
With the PostgreSQL tool, you can:
|
||||
|
||||
- **Query data**: Execute SELECT queries to retrieve data from your PostgreSQL tables using the `postgresql_query` operation.
|
||||
- **Insert records**: Add new rows to your tables with the `postgresql_insert` operation by specifying the table and data to insert.
|
||||
- **Update records**: Modify existing data in your tables using the `postgresql_update` operation, providing the table, new data, and WHERE conditions.
|
||||
- **Delete records**: Remove rows from your tables with the `postgresql_delete` operation, specifying the table and WHERE conditions.
|
||||
- **Execute raw SQL**: Run any custom SQL command using the `postgresql_execute` operation for advanced use cases.
|
||||
|
||||
The PostgreSQL tool is ideal for scenarios where your agents need to interact with structured data—such as automating reporting, syncing data between systems, or powering data-driven workflows. It streamlines database access, making it easy to read, write, and manage your PostgreSQL data programmatically.
|
||||
{/* MANUAL-CONTENT-END */}
|
||||
|
||||
|
||||
## Usage Instructions
|
||||
|
||||
Connect to any PostgreSQL database to execute queries, manage data, and perform database operations. Supports SELECT, INSERT, UPDATE, DELETE operations with secure connection handling.
|
||||
|
||||
|
||||
|
||||
## Tools
|
||||
|
||||
### `postgresql_query`
|
||||
|
||||
Execute a SELECT query on PostgreSQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | PostgreSQL server hostname or IP address |
|
||||
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `query` | string | Yes | SQL SELECT query to execute |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of rows returned from the query |
|
||||
| `rowCount` | number | Number of rows returned |
|
||||
|
||||
### `postgresql_insert`
|
||||
|
||||
Insert data into PostgreSQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | PostgreSQL server hostname or IP address |
|
||||
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to insert data into |
|
||||
| `data` | object | Yes | Data object to insert \(key-value pairs\) |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Inserted data \(if RETURNING clause used\) |
|
||||
| `rowCount` | number | Number of rows inserted |
|
||||
|
||||
### `postgresql_update`
|
||||
|
||||
Update data in PostgreSQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | PostgreSQL server hostname or IP address |
|
||||
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to update data in |
|
||||
| `data` | object | Yes | Data object with fields to update \(key-value pairs\) |
|
||||
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Updated data \(if RETURNING clause used\) |
|
||||
| `rowCount` | number | Number of rows updated |
|
||||
|
||||
### `postgresql_delete`
|
||||
|
||||
Delete data from PostgreSQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | PostgreSQL server hostname or IP address |
|
||||
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `table` | string | Yes | Table name to delete data from |
|
||||
| `where` | string | Yes | WHERE clause condition \(without WHERE keyword\) |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Deleted data \(if RETURNING clause used\) |
|
||||
| `rowCount` | number | Number of rows deleted |
|
||||
|
||||
### `postgresql_execute`
|
||||
|
||||
Execute raw SQL query on PostgreSQL database
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `host` | string | Yes | PostgreSQL server hostname or IP address |
|
||||
| `port` | number | Yes | PostgreSQL server port \(default: 5432\) |
|
||||
| `database` | string | Yes | Database name to connect to |
|
||||
| `username` | string | Yes | Database username |
|
||||
| `password` | string | Yes | Database password |
|
||||
| `ssl` | string | No | SSL connection mode \(disabled, required, preferred\) |
|
||||
| `query` | string | Yes | Raw SQL query to execute |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `rows` | array | Array of rows returned from the query |
|
||||
| `rowCount` | number | Number of rows affected |
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- Category: `tools`
|
||||
- Type: `postgresql`
|
||||
@@ -142,7 +142,7 @@ Get a single row from a Supabase table based on filter criteria
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `results` | object | The row data if found, null if not found |
|
||||
| `results` | array | Array containing the row data if found, empty array if not found |
|
||||
|
||||
### `supabase_update`
|
||||
|
||||
@@ -185,6 +185,26 @@ Delete rows from a Supabase table based on filter criteria
|
||||
| `message` | string | Operation status message |
|
||||
| `results` | array | Array of deleted records |
|
||||
|
||||
### `supabase_upsert`
|
||||
|
||||
Insert or update data in a Supabase table (upsert operation)
|
||||
|
||||
#### Input
|
||||
|
||||
| Parameter | Type | Required | Description |
|
||||
| --------- | ---- | -------- | ----------- |
|
||||
| `projectId` | string | Yes | Your Supabase project ID \(e.g., jdrkgepadsdopsntdlom\) |
|
||||
| `table` | string | Yes | The name of the Supabase table to upsert data into |
|
||||
| `data` | any | Yes | The data to upsert \(insert or update\) |
|
||||
| `apiKey` | string | Yes | Your Supabase service role secret key |
|
||||
|
||||
#### Output
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| --------- | ---- | ----------- |
|
||||
| `message` | string | Operation status message |
|
||||
| `results` | array | Array of upserted records |
|
||||
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
# Database (Required)
|
||||
DATABASE_URL="postgresql://postgres:password@localhost:5432/postgres"
|
||||
|
||||
# PostgreSQL Port (Optional) - defaults to 5432 if not specified
|
||||
# POSTGRES_PORT=5432
|
||||
|
||||
# Authentication (Required)
|
||||
BETTER_AUTH_SECRET=your_secret_key # Use `openssl rand -hex 32` to generate, or visit https://www.better-auth.com/docs/installation
|
||||
BETTER_AUTH_URL=http://localhost:3000
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import { useEffect, useState } from 'react'
|
||||
import { GithubIcon, GoogleIcon } from '@/components/icons'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip'
|
||||
import { client } from '@/lib/auth-client'
|
||||
|
||||
interface SocialLoginButtonsProps {
|
||||
@@ -114,58 +113,16 @@ export function SocialLoginButtons({
|
||||
</Button>
|
||||
)
|
||||
|
||||
const renderGithubButton = () => {
|
||||
if (githubAvailable) return githubButton
|
||||
const hasAnyOAuthProvider = githubAvailable || googleAvailable
|
||||
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>{githubButton}</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className='border-neutral-700 bg-neutral-800 text-white'>
|
||||
<p>
|
||||
GitHub login requires OAuth credentials to be configured. Add the following
|
||||
environment variables:
|
||||
</p>
|
||||
<ul className='mt-2 space-y-1 text-neutral-300 text-xs'>
|
||||
<li>• GITHUB_CLIENT_ID</li>
|
||||
<li>• GITHUB_CLIENT_SECRET</li>
|
||||
</ul>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)
|
||||
}
|
||||
|
||||
const renderGoogleButton = () => {
|
||||
if (googleAvailable) return googleButton
|
||||
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>{googleButton}</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className='border-neutral-700 bg-neutral-800 text-white'>
|
||||
<p>
|
||||
Google login requires OAuth credentials to be configured. Add the following
|
||||
environment variables:
|
||||
</p>
|
||||
<ul className='mt-2 space-y-1 text-neutral-300 text-xs'>
|
||||
<li>• GOOGLE_CLIENT_ID</li>
|
||||
<li>• GOOGLE_CLIENT_SECRET</li>
|
||||
</ul>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
)
|
||||
if (!hasAnyOAuthProvider) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<div className='grid gap-3'>
|
||||
{renderGithubButton()}
|
||||
{renderGoogleButton()}
|
||||
{githubAvailable && githubButton}
|
||||
{googleAvailable && googleButton}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ export default function AuthLayout({ children }: { children: React.ReactNode })
|
||||
const brand = useBrandConfig()
|
||||
|
||||
return (
|
||||
<main className='relative flex min-h-screen flex-col bg-[#0C0C0C] font-geist-sans text-white'>
|
||||
<main className='relative flex min-h-screen flex-col bg-[var(--brand-background-hex)] font-geist-sans text-white'>
|
||||
{/* Background pattern */}
|
||||
<GridPattern
|
||||
x={-5}
|
||||
@@ -28,12 +28,12 @@ export default function AuthLayout({ children }: { children: React.ReactNode })
|
||||
<img
|
||||
src={brand.logoUrl}
|
||||
alt={`${brand.name} Logo`}
|
||||
width={42}
|
||||
height={42}
|
||||
className='h-[42px] w-[42px] object-contain'
|
||||
width={56}
|
||||
height={56}
|
||||
className='h-[56px] w-[56px] object-contain'
|
||||
/>
|
||||
) : (
|
||||
<Image src='/sim.svg' alt={`${brand.name} Logo`} width={42} height={42} />
|
||||
<Image src='/sim.svg' alt={`${brand.name} Logo`} width={56} height={56} />
|
||||
)}
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
@@ -49,15 +49,12 @@ const PASSWORD_VALIDATIONS = {
|
||||
},
|
||||
}
|
||||
|
||||
// Validate callback URL to prevent open redirect vulnerabilities
|
||||
const validateCallbackUrl = (url: string): boolean => {
|
||||
try {
|
||||
// If it's a relative URL, it's safe
|
||||
if (url.startsWith('/')) {
|
||||
return true
|
||||
}
|
||||
|
||||
// If absolute URL, check if it belongs to the same origin
|
||||
const currentOrigin = typeof window !== 'undefined' ? window.location.origin : ''
|
||||
if (url.startsWith(currentOrigin)) {
|
||||
return true
|
||||
@@ -70,7 +67,6 @@ const validateCallbackUrl = (url: string): boolean => {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate password and return array of error messages
|
||||
const validatePassword = (passwordValue: string): string[] => {
|
||||
const errors: string[] = []
|
||||
|
||||
@@ -308,6 +304,15 @@ export default function LoginPage({
|
||||
return
|
||||
}
|
||||
|
||||
const emailValidation = quickValidateEmail(forgotPasswordEmail.trim().toLowerCase())
|
||||
if (!emailValidation.isValid) {
|
||||
setResetStatus({
|
||||
type: 'error',
|
||||
message: 'Please enter a valid email address',
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
setIsSubmittingReset(true)
|
||||
setResetStatus({ type: null, message: '' })
|
||||
@@ -325,7 +330,23 @@ export default function LoginPage({
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json()
|
||||
throw new Error(errorData.message || 'Failed to request password reset')
|
||||
let errorMessage = errorData.message || 'Failed to request password reset'
|
||||
|
||||
if (
|
||||
errorMessage.includes('Invalid body parameters') ||
|
||||
errorMessage.includes('invalid email')
|
||||
) {
|
||||
errorMessage = 'Please enter a valid email address'
|
||||
} else if (errorMessage.includes('Email is required')) {
|
||||
errorMessage = 'Please enter your email address'
|
||||
} else if (
|
||||
errorMessage.includes('user not found') ||
|
||||
errorMessage.includes('User not found')
|
||||
) {
|
||||
errorMessage = 'No account found with this email address'
|
||||
}
|
||||
|
||||
throw new Error(errorMessage)
|
||||
}
|
||||
|
||||
setResetStatus({
|
||||
@@ -366,11 +387,13 @@ export default function LoginPage({
|
||||
callbackURL={callbackUrl}
|
||||
/>
|
||||
|
||||
<div className='relative mt-2 py-4'>
|
||||
<div className='absolute inset-0 flex items-center'>
|
||||
<div className='w-full border-neutral-700/50 border-t' />
|
||||
{(githubAvailable || googleAvailable) && (
|
||||
<div className='relative mt-2 py-4'>
|
||||
<div className='absolute inset-0 flex items-center'>
|
||||
<div className='w-full border-neutral-700/50 border-t' />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form onSubmit={onSubmit} className='space-y-5'>
|
||||
<div className='space-y-4'>
|
||||
@@ -456,7 +479,7 @@ export default function LoginPage({
|
||||
|
||||
<Button
|
||||
type='submit'
|
||||
className='flex h-11 w-full items-center justify-center gap-2 bg-[#701ffc] font-medium text-base text-white shadow-[#701ffc]/20 shadow-lg transition-colors duration-200 hover:bg-[#802FFF]'
|
||||
className='flex h-11 w-full items-center justify-center gap-2 bg-brand-primary font-medium text-base text-white shadow-[var(--brand-primary-hex)]/20 shadow-lg transition-colors duration-200 hover:bg-brand-primary-hover'
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading ? 'Signing in...' : 'Sign In'}
|
||||
@@ -468,11 +491,28 @@ export default function LoginPage({
|
||||
<span className='text-neutral-400'>Don't have an account? </span>
|
||||
<Link
|
||||
href={isInviteFlow ? `/signup?invite_flow=true&callbackUrl=${callbackUrl}` : '/signup'}
|
||||
className='font-medium text-[#9D54FF] underline-offset-4 transition hover:text-[#a66fff] hover:underline'
|
||||
className='font-medium text-[var(--brand-accent-hex)] underline-offset-4 transition hover:text-[var(--brand-accent-hover-hex)] hover:underline'
|
||||
>
|
||||
Sign up
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
|
||||
By signing in, you agree to our{' '}
|
||||
<Link
|
||||
href='/terms'
|
||||
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
|
||||
>
|
||||
Terms of Service
|
||||
</Link>{' '}
|
||||
and{' '}
|
||||
<Link
|
||||
href='/privacy'
|
||||
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
|
||||
>
|
||||
Privacy Policy
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Dialog open={forgotPasswordOpen} onOpenChange={setForgotPasswordOpen}>
|
||||
@@ -482,7 +522,8 @@ export default function LoginPage({
|
||||
Reset Password
|
||||
</DialogTitle>
|
||||
<DialogDescription className='text-neutral-300 text-sm'>
|
||||
Enter your email address and we'll send you a link to reset your password.
|
||||
Enter your email address and we'll send you a link to reset your password if your
|
||||
account exists.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
<div className='space-y-4'>
|
||||
@@ -497,22 +538,26 @@ export default function LoginPage({
|
||||
placeholder='Enter your email'
|
||||
required
|
||||
type='email'
|
||||
className='border-neutral-700/80 bg-neutral-900 text-white placeholder:text-white/60 focus:border-[#802FFF]/70 focus:ring-[#802FFF]/20'
|
||||
className={cn(
|
||||
'border-neutral-700/80 bg-neutral-900 text-white placeholder:text-white/60 focus:border-[var(--brand-primary-hover-hex)]/70 focus:ring-[var(--brand-primary-hover-hex)]/20',
|
||||
resetStatus.type === 'error' && 'border-red-500 focus-visible:ring-red-500'
|
||||
)}
|
||||
/>
|
||||
{resetStatus.type === 'error' && (
|
||||
<div className='mt-1 space-y-1 text-red-400 text-xs'>
|
||||
<p>{resetStatus.message}</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{resetStatus.type && (
|
||||
<div
|
||||
className={`text-sm ${
|
||||
resetStatus.type === 'success' ? 'text-[#4CAF50]' : 'text-red-500'
|
||||
}`}
|
||||
>
|
||||
{resetStatus.message}
|
||||
{resetStatus.type === 'success' && (
|
||||
<div className='mt-1 space-y-1 text-[#4CAF50] text-xs'>
|
||||
<p>{resetStatus.message}</p>
|
||||
</div>
|
||||
)}
|
||||
<Button
|
||||
type='button'
|
||||
onClick={handleForgotPassword}
|
||||
className='h-11 w-full bg-[#701ffc] font-medium text-base text-white shadow-[#701ffc]/20 shadow-lg transition-colors duration-200 hover:bg-[#802FFF]'
|
||||
className='h-11 w-full bg-[var(--brand-primary-hex)] font-medium text-base text-white shadow-[var(--brand-primary-hex)]/20 shadow-lg transition-colors duration-200 hover:bg-[var(--brand-primary-hover-hex)]'
|
||||
disabled={isSubmittingReset}
|
||||
>
|
||||
{isSubmittingReset ? 'Sending...' : 'Send Reset Link'}
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { client } from '@/lib/auth-client'
|
||||
import { client, useSession } from '@/lib/auth-client'
|
||||
import SignupPage from '@/app/(auth)/signup/signup-form'
|
||||
|
||||
vi.mock('next/navigation', () => ({
|
||||
@@ -22,6 +22,7 @@ vi.mock('@/lib/auth-client', () => ({
|
||||
sendVerificationOtp: vi.fn(),
|
||||
},
|
||||
},
|
||||
useSession: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/(auth)/components/social-login-buttons', () => ({
|
||||
@@ -43,6 +44,9 @@ describe('SignupPage', () => {
|
||||
vi.clearAllMocks()
|
||||
;(useRouter as any).mockReturnValue(mockRouter)
|
||||
;(useSearchParams as any).mockReturnValue(mockSearchParams)
|
||||
;(useSession as any).mockReturnValue({
|
||||
refetch: vi.fn().mockResolvedValue({}),
|
||||
})
|
||||
mockSearchParams.get.mockReturnValue(null)
|
||||
})
|
||||
|
||||
@@ -162,8 +166,9 @@ describe('SignupPage', () => {
|
||||
})
|
||||
})
|
||||
|
||||
it('should prevent submission with invalid name validation', async () => {
|
||||
it('should automatically trim spaces from name input', async () => {
|
||||
const mockSignUp = vi.mocked(client.signUp.email)
|
||||
mockSignUp.mockResolvedValue({ data: null, error: null })
|
||||
|
||||
render(<SignupPage {...defaultProps} />)
|
||||
|
||||
@@ -172,22 +177,20 @@ describe('SignupPage', () => {
|
||||
const passwordInput = screen.getByPlaceholderText(/enter your password/i)
|
||||
const submitButton = screen.getByRole('button', { name: /create account/i })
|
||||
|
||||
// Use name with leading/trailing spaces which should fail validation
|
||||
fireEvent.change(nameInput, { target: { value: ' John Doe ' } })
|
||||
fireEvent.change(emailInput, { target: { value: 'user@company.com' } })
|
||||
fireEvent.change(passwordInput, { target: { value: 'Password123!' } })
|
||||
fireEvent.click(submitButton)
|
||||
|
||||
// Should not call signUp because validation failed
|
||||
expect(mockSignUp).not.toHaveBeenCalled()
|
||||
|
||||
// Should show validation error
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(
|
||||
/Name cannot contain consecutive spaces|Name cannot start or end with spaces/
|
||||
)
|
||||
).toBeInTheDocument()
|
||||
expect(mockSignUp).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
name: 'John Doe',
|
||||
email: 'user@company.com',
|
||||
password: 'Password123!',
|
||||
}),
|
||||
expect.any(Object)
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { Button } from '@/components/ui/button'
|
||||
import { Input } from '@/components/ui/input'
|
||||
import { Label } from '@/components/ui/label'
|
||||
import { client } from '@/lib/auth-client'
|
||||
import { client, useSession } from '@/lib/auth-client'
|
||||
import { quickValidateEmail } from '@/lib/email/validation'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { cn } from '@/lib/utils'
|
||||
@@ -49,10 +49,6 @@ const NAME_VALIDATIONS = {
|
||||
regex: /^(?!.*\s\s).*$/,
|
||||
message: 'Name cannot contain consecutive spaces.',
|
||||
},
|
||||
noLeadingTrailingSpaces: {
|
||||
test: (value: string) => value === value.trim(),
|
||||
message: 'Name cannot start or end with spaces.',
|
||||
},
|
||||
}
|
||||
|
||||
const validateEmailField = (emailValue: string): string[] => {
|
||||
@@ -82,6 +78,7 @@ function SignupFormContent({
|
||||
}) {
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const { refetch: refetchSession } = useSession()
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [, setMounted] = useState(false)
|
||||
const [showPassword, setShowPassword] = useState(false)
|
||||
@@ -174,10 +171,6 @@ function SignupFormContent({
|
||||
errors.push(NAME_VALIDATIONS.noConsecutiveSpaces.message)
|
||||
}
|
||||
|
||||
if (!NAME_VALIDATIONS.noLeadingTrailingSpaces.test(nameValue)) {
|
||||
errors.push(NAME_VALIDATIONS.noLeadingTrailingSpaces.message)
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
@@ -192,11 +185,10 @@ function SignupFormContent({
|
||||
}
|
||||
|
||||
const handleNameChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const newName = e.target.value
|
||||
setName(newName)
|
||||
const rawValue = e.target.value
|
||||
setName(rawValue)
|
||||
|
||||
// Silently validate but don't show errors until submit
|
||||
const errors = validateName(newName)
|
||||
const errors = validateName(rawValue)
|
||||
setNameErrors(errors)
|
||||
setShowNameValidationError(false)
|
||||
}
|
||||
@@ -223,23 +215,21 @@ function SignupFormContent({
|
||||
const formData = new FormData(e.currentTarget)
|
||||
const emailValue = formData.get('email') as string
|
||||
const passwordValue = formData.get('password') as string
|
||||
const name = formData.get('name') as string
|
||||
const nameValue = formData.get('name') as string
|
||||
|
||||
// Validate name on submit
|
||||
const nameValidationErrors = validateName(name)
|
||||
const trimmedName = nameValue.trim()
|
||||
|
||||
const nameValidationErrors = validateName(trimmedName)
|
||||
setNameErrors(nameValidationErrors)
|
||||
setShowNameValidationError(nameValidationErrors.length > 0)
|
||||
|
||||
// Validate email on submit
|
||||
const emailValidationErrors = validateEmailField(emailValue)
|
||||
setEmailErrors(emailValidationErrors)
|
||||
setShowEmailValidationError(emailValidationErrors.length > 0)
|
||||
|
||||
// Validate password on submit
|
||||
const errors = validatePassword(passwordValue)
|
||||
setPasswordErrors(errors)
|
||||
|
||||
// Only show validation errors if there are any
|
||||
setShowValidationError(errors.length > 0)
|
||||
|
||||
try {
|
||||
@@ -248,7 +238,6 @@ function SignupFormContent({
|
||||
emailValidationErrors.length > 0 ||
|
||||
errors.length > 0
|
||||
) {
|
||||
// Prioritize name errors first, then email errors, then password errors
|
||||
if (nameValidationErrors.length > 0) {
|
||||
setNameErrors([nameValidationErrors[0]])
|
||||
setShowNameValidationError(true)
|
||||
@@ -265,8 +254,6 @@ function SignupFormContent({
|
||||
return
|
||||
}
|
||||
|
||||
// Check if name will be truncated and warn user
|
||||
const trimmedName = name.trim()
|
||||
if (trimmedName.length > 100) {
|
||||
setNameErrors(['Name will be truncated to 100 characters. Please shorten your name.'])
|
||||
setShowNameValidationError(true)
|
||||
@@ -330,6 +317,14 @@ function SignupFormContent({
|
||||
return
|
||||
}
|
||||
|
||||
// Refresh session to get the new user data immediately after signup
|
||||
try {
|
||||
await refetchSession()
|
||||
logger.info('Session refreshed after successful signup')
|
||||
} catch (sessionError) {
|
||||
logger.error('Failed to refresh session after signup:', sessionError)
|
||||
}
|
||||
|
||||
// For new signups, always require verification
|
||||
if (typeof window !== 'undefined') {
|
||||
sessionStorage.setItem('verificationEmail', emailValue)
|
||||
@@ -381,11 +376,13 @@ function SignupFormContent({
|
||||
isProduction={isProduction}
|
||||
/>
|
||||
|
||||
<div className='relative mt-2 py-4'>
|
||||
<div className='absolute inset-0 flex items-center'>
|
||||
<div className='w-full border-neutral-700/50 border-t' />
|
||||
{(githubAvailable || googleAvailable) && (
|
||||
<div className='relative mt-2 py-4'>
|
||||
<div className='absolute inset-0 flex items-center'>
|
||||
<div className='w-full border-neutral-700/50 border-t' />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<form onSubmit={onSubmit} className='space-y-5'>
|
||||
<div className='space-y-4'>
|
||||
@@ -488,7 +485,7 @@ function SignupFormContent({
|
||||
|
||||
<Button
|
||||
type='submit'
|
||||
className='flex h-11 w-full items-center justify-center gap-2 bg-[#701ffc] font-medium text-base text-white shadow-[#701ffc]/20 shadow-lg transition-colors duration-200 hover:bg-[#802FFF]'
|
||||
className='flex h-11 w-full items-center justify-center gap-2 bg-[var(--brand-primary-hex)] font-medium text-base text-white shadow-[var(--brand-primary-hex)]/20 shadow-lg transition-colors duration-200 hover:bg-[var(--brand-primary-hover-hex)]'
|
||||
disabled={isLoading}
|
||||
>
|
||||
{isLoading ? 'Creating account...' : 'Create Account'}
|
||||
@@ -500,11 +497,28 @@ function SignupFormContent({
|
||||
<span className='text-neutral-400'>Already have an account? </span>
|
||||
<Link
|
||||
href={isInviteFlow ? `/login?invite_flow=true&callbackUrl=${redirectUrl}` : '/login'}
|
||||
className='font-medium text-[#9D54FF] underline-offset-4 transition hover:text-[#a66fff] hover:underline'
|
||||
className='font-medium text-[var(--brand-accent-hex)] underline-offset-4 transition hover:text-[var(--brand-accent-hover-hex)] hover:underline'
|
||||
>
|
||||
Sign in
|
||||
</Link>
|
||||
</div>
|
||||
|
||||
<div className='text-center text-neutral-500/80 text-xs leading-relaxed'>
|
||||
By creating an account, you agree to our{' '}
|
||||
<Link
|
||||
href='/terms'
|
||||
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
|
||||
>
|
||||
Terms of Service
|
||||
</Link>{' '}
|
||||
and{' '}
|
||||
<Link
|
||||
href='/privacy'
|
||||
className='text-neutral-400 underline-offset-4 transition hover:text-neutral-300 hover:underline'
|
||||
>
|
||||
Privacy Policy
|
||||
</Link>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import { client } from '@/lib/auth-client'
|
||||
import { client, useSession } from '@/lib/auth-client'
|
||||
import { env, isTruthy } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
@@ -34,6 +34,7 @@ export function useVerification({
|
||||
}: UseVerificationParams): UseVerificationReturn {
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const { refetch: refetchSession } = useSession()
|
||||
const [otp, setOtp] = useState('')
|
||||
const [email, setEmail] = useState('')
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
@@ -136,16 +137,15 @@ export function useVerification({
|
||||
}
|
||||
}
|
||||
|
||||
// Redirect to proper page after a short delay
|
||||
setTimeout(() => {
|
||||
if (isInviteFlow && redirectUrl) {
|
||||
// For invitation flow, redirect to the invitation page
|
||||
router.push(redirectUrl)
|
||||
window.location.href = redirectUrl
|
||||
} else {
|
||||
// Default redirect to dashboard
|
||||
router.push('/workspace')
|
||||
window.location.href = '/workspace'
|
||||
}
|
||||
}, 2000)
|
||||
}, 1000)
|
||||
} else {
|
||||
logger.info('Setting invalid OTP state - API error response')
|
||||
const message = 'Invalid verification code. Please check and try again.'
|
||||
@@ -215,25 +215,33 @@ export function useVerification({
|
||||
setOtp(value)
|
||||
}
|
||||
|
||||
// Auto-submit when OTP is complete
|
||||
useEffect(() => {
|
||||
if (otp.length === 6 && email && !isLoading && !isVerified) {
|
||||
const timeoutId = setTimeout(() => {
|
||||
verifyCode()
|
||||
}, 300) // Small delay to ensure UI is ready
|
||||
|
||||
return () => clearTimeout(timeoutId)
|
||||
}
|
||||
}, [otp, email, isLoading, isVerified])
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window !== 'undefined') {
|
||||
if (!isProduction || !hasResendKey) {
|
||||
const storedEmail = sessionStorage.getItem('verificationEmail')
|
||||
logger.info('Auto-verifying user', { email: storedEmail })
|
||||
}
|
||||
|
||||
const isDevOrDocker = !isProduction || isTruthy(env.DOCKER_BUILD)
|
||||
|
||||
// Auto-verify and redirect in development/docker environments
|
||||
if (isDevOrDocker || !hasResendKey) {
|
||||
setIsVerified(true)
|
||||
|
||||
// Clear verification requirement cookie (same as manual verification)
|
||||
document.cookie =
|
||||
'requiresEmailVerification=; path=/; expires=Thu, 01 Jan 1970 00:00:00 GMT'
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
router.push('/workspace')
|
||||
window.location.href = '/workspace'
|
||||
}, 1000)
|
||||
|
||||
return () => clearTimeout(timeoutId)
|
||||
|
||||
@@ -124,7 +124,7 @@ function VerificationForm({
|
||||
|
||||
<Button
|
||||
onClick={verifyCode}
|
||||
className='h-11 w-full bg-[#701ffc] font-medium text-base text-white shadow-[#701ffc]/20 shadow-lg transition-colors duration-200 hover:bg-[#802FFF]'
|
||||
className='h-11 w-full bg-[var(--brand-primary-hex)] font-medium text-base text-white shadow-[var(--brand-primary-hex)]/20 shadow-lg transition-colors duration-200 hover:bg-[var(--brand-primary-hover-hex)]'
|
||||
disabled={!isOtpComplete || isLoading}
|
||||
>
|
||||
{isLoading ? 'Verifying...' : 'Verify Email'}
|
||||
@@ -140,7 +140,7 @@ function VerificationForm({
|
||||
</span>
|
||||
) : (
|
||||
<button
|
||||
className='font-medium text-[#9D54FF] underline-offset-4 transition hover:text-[#a66fff] hover:underline'
|
||||
className='font-medium text-[var(--brand-accent-hex)] underline-offset-4 transition hover:text-[var(--brand-accent-hover-hex)] hover:underline'
|
||||
onClick={handleResend}
|
||||
disabled={isLoading || isResendDisabled}
|
||||
>
|
||||
|
||||
@@ -35,7 +35,7 @@ export const BlogCard = ({
|
||||
}: BlogCardProps) => {
|
||||
return (
|
||||
<Link href={href}>
|
||||
<div className='flex flex-col rounded-3xl border border-[#606060]/40 bg-[#101010] p-8 transition-all duration-500 hover:bg-[#202020]'>
|
||||
<div className='flex flex-col rounded-3xl border border-[#606060]/40 bg-[#101010] p-8 transition-all duration-500 hover:bg-[var(--surface-elevated)]'>
|
||||
{image ? (
|
||||
<Image
|
||||
src={image}
|
||||
|
||||
@@ -245,7 +245,7 @@ export default function NavClient({
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
>
|
||||
<Button className='h-[43px] bg-[#701ffc] px-6 py-2 font-geist-sans font-medium text-base text-neutral-100 transition-colors duration-200 hover:bg-[#802FFF]'>
|
||||
<Button className='h-[43px] bg-[var(--brand-primary-hex)] px-6 py-2 font-geist-sans font-medium text-base text-neutral-100 transition-colors duration-200 hover:bg-[var(--brand-primary-hover-hex)]'>
|
||||
Contact
|
||||
</Button>
|
||||
</Link>
|
||||
@@ -277,7 +277,7 @@ export default function NavClient({
|
||||
>
|
||||
<SheetContent
|
||||
side='right'
|
||||
className='flex h-full w-[280px] flex-col border-[#181818] border-l bg-[#0C0C0C] p-6 pt-6 text-white shadow-xl sm:w-[320px] [&>button]:hidden'
|
||||
className='flex h-full w-[280px] flex-col border-[#181818] border-l bg-[var(--brand-background-hex)] p-6 pt-6 text-white shadow-xl sm:w-[320px] [&>button]:hidden'
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
onCloseAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
@@ -311,7 +311,7 @@ export default function NavClient({
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
>
|
||||
<Button className='w-full bg-[#701ffc] py-6 font-medium text-base text-white shadow-[#701ffc]/20 shadow-lg transition-colors duration-200 hover:bg-[#802FFF]'>
|
||||
<Button className='w-full bg-[var(--brand-primary-hex)] py-6 font-medium text-base text-white shadow-[var(--brand-primary-hex)]/20 shadow-lg transition-colors duration-200 hover:bg-[var(--brand-primary-hover-hex)]'>
|
||||
Contact
|
||||
</Button>
|
||||
</Link>
|
||||
|
||||
@@ -62,7 +62,7 @@ function Hero() {
|
||||
<Button
|
||||
variant={'secondary'}
|
||||
onClick={handleNavigate}
|
||||
className='animate-fade-in items-center bg-[#701ffc] px-7 py-6 font-[420] font-geist-sans text-lg text-neutral-100 tracking-normal shadow-[#701ffc]/30 shadow-lg hover:bg-[#802FFF]'
|
||||
className='animate-fade-in items-center bg-[var(--brand-primary-hex)] px-7 py-6 font-[420] font-geist-sans text-lg text-neutral-100 tracking-normal shadow-[var(--brand-primary-hex)]/30 shadow-lg hover:bg-[var(--brand-primary-hover-hex)]'
|
||||
aria-label='Start using the platform'
|
||||
>
|
||||
<div className='text-[1.15rem]'>Start now</div>
|
||||
@@ -104,7 +104,7 @@ function Hero() {
|
||||
className='aspect-[5/3] h-auto md:aspect-auto'
|
||||
>
|
||||
<g filter='url(#filter0_b_0_1)'>
|
||||
<ellipse cx='300' cy='240' rx='290' ry='220' fill='#0C0C0C' />
|
||||
<ellipse cx='300' cy='240' rx='290' ry='220' fill='var(--brand-background-hex)' />
|
||||
</g>
|
||||
<defs>
|
||||
<filter
|
||||
|
||||
@@ -151,7 +151,7 @@ export default function ContributorsPage() {
|
||||
)
|
||||
|
||||
return (
|
||||
<main className='relative min-h-screen bg-[#0C0C0C] font-geist-sans text-white'>
|
||||
<main className='relative min-h-screen bg-[var(--brand-background-hex)] font-geist-sans text-white'>
|
||||
{/* Grid pattern background */}
|
||||
<div className='absolute inset-0 bottom-[400px] z-0'>
|
||||
<GridPattern
|
||||
@@ -239,7 +239,7 @@ export default function ContributorsPage() {
|
||||
<div className='mb-6 grid grid-cols-1 gap-3 sm:mb-8 sm:grid-cols-2 sm:gap-4 lg:grid-cols-5'>
|
||||
<div className='rounded-lg border border-[#606060]/20 bg-neutral-800/30 p-3 text-center sm:rounded-xl sm:p-4'>
|
||||
<div className='mb-1 flex items-center justify-center sm:mb-2'>
|
||||
<Star className='h-4 w-4 text-[#701ffc] sm:h-5 sm:w-5' />
|
||||
<Star className='h-4 w-4 text-[var(--brand-primary-hex)] sm:h-5 sm:w-5' />
|
||||
</div>
|
||||
<div className='font-bold text-lg text-white sm:text-xl'>{repoStats.stars}</div>
|
||||
<div className='text-neutral-400 text-xs'>Stars</div>
|
||||
@@ -247,7 +247,7 @@ export default function ContributorsPage() {
|
||||
|
||||
<div className='rounded-lg border border-[#606060]/20 bg-neutral-800/30 p-3 text-center sm:rounded-xl sm:p-4'>
|
||||
<div className='mb-1 flex items-center justify-center sm:mb-2'>
|
||||
<GitFork className='h-4 w-4 text-[#701ffc] sm:h-5 sm:w-5' />
|
||||
<GitFork className='h-4 w-4 text-[var(--brand-primary-hex)] sm:h-5 sm:w-5' />
|
||||
</div>
|
||||
<div className='font-bold text-lg text-white sm:text-xl'>{repoStats.forks}</div>
|
||||
<div className='text-neutral-400 text-xs'>Forks</div>
|
||||
@@ -255,7 +255,7 @@ export default function ContributorsPage() {
|
||||
|
||||
<div className='rounded-lg border border-[#606060]/20 bg-neutral-800/30 p-3 text-center sm:rounded-xl sm:p-4'>
|
||||
<div className='mb-1 flex items-center justify-center sm:mb-2'>
|
||||
<GitGraph className='h-4 w-4 text-[#701ffc] sm:h-5 sm:w-5' />
|
||||
<GitGraph className='h-4 w-4 text-[var(--brand-primary-hex)] sm:h-5 sm:w-5' />
|
||||
</div>
|
||||
<div className='font-bold text-lg text-white sm:text-xl'>
|
||||
{filteredContributors?.length || 0}
|
||||
@@ -265,7 +265,7 @@ export default function ContributorsPage() {
|
||||
|
||||
<div className='rounded-lg border border-[#606060]/20 bg-neutral-800/30 p-3 text-center sm:rounded-xl sm:p-4'>
|
||||
<div className='mb-1 flex items-center justify-center sm:mb-2'>
|
||||
<MessageCircle className='h-4 w-4 text-[#701ffc] sm:h-5 sm:w-5' />
|
||||
<MessageCircle className='h-4 w-4 text-[var(--brand-primary-hex)] sm:h-5 sm:w-5' />
|
||||
</div>
|
||||
<div className='font-bold text-lg text-white sm:text-xl'>
|
||||
{repoStats.openIssues}
|
||||
@@ -275,7 +275,7 @@ export default function ContributorsPage() {
|
||||
|
||||
<div className='rounded-lg border border-[#606060]/20 bg-neutral-800/30 p-3 text-center sm:rounded-xl sm:p-4'>
|
||||
<div className='mb-1 flex items-center justify-center sm:mb-2'>
|
||||
<GitPullRequest className='h-4 w-4 text-[#701ffc] sm:h-5 sm:w-5' />
|
||||
<GitPullRequest className='h-4 w-4 text-[var(--brand-primary-hex)] sm:h-5 sm:w-5' />
|
||||
</div>
|
||||
<div className='font-bold text-lg text-white sm:text-xl'>{repoStats.openPRs}</div>
|
||||
<div className='text-neutral-400 text-xs'>Pull Requests</div>
|
||||
@@ -291,8 +291,8 @@ export default function ContributorsPage() {
|
||||
<AreaChart data={timelineData} className='-mx-2 sm:-mx-5 mt-1 sm:mt-2'>
|
||||
<defs>
|
||||
<linearGradient id='commits' x1='0' y1='0' x2='0' y2='1'>
|
||||
<stop offset='5%' stopColor='#701ffc' stopOpacity={0.3} />
|
||||
<stop offset='95%' stopColor='#701ffc' stopOpacity={0} />
|
||||
<stop offset='5%' stopColor='var(--brand-primary-hex)' stopOpacity={0.3} />
|
||||
<stop offset='95%' stopColor='var(--brand-primary-hex)' stopOpacity={0} />
|
||||
</linearGradient>
|
||||
</defs>
|
||||
<XAxis
|
||||
@@ -320,7 +320,7 @@ export default function ContributorsPage() {
|
||||
<div className='rounded-lg border border-[#606060]/30 bg-[#0f0f0f] p-2 shadow-lg backdrop-blur-sm sm:p-3'>
|
||||
<div className='grid gap-1 sm:gap-2'>
|
||||
<div className='flex items-center gap-1 sm:gap-2'>
|
||||
<GitGraph className='h-3 w-3 text-[#701ffc] sm:h-4 sm:w-4' />
|
||||
<GitGraph className='h-3 w-3 text-[var(--brand-primary-hex)] sm:h-4 sm:w-4' />
|
||||
<span className='text-neutral-400 text-xs sm:text-sm'>
|
||||
Commits:
|
||||
</span>
|
||||
@@ -338,7 +338,7 @@ export default function ContributorsPage() {
|
||||
<Area
|
||||
type='monotone'
|
||||
dataKey='commits'
|
||||
stroke='#701ffc'
|
||||
stroke='var(--brand-primary-hex)'
|
||||
strokeWidth={2}
|
||||
fill='url(#commits)'
|
||||
/>
|
||||
@@ -393,7 +393,7 @@ export default function ContributorsPage() {
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
style={{ animationDelay: `${index * 50}ms` }}
|
||||
>
|
||||
<Avatar className='h-12 w-12 ring-2 ring-[#606060]/30 transition-transform group-hover:scale-105 group-hover:ring-[#701ffc]/60 sm:h-16 sm:w-16'>
|
||||
<Avatar className='h-12 w-12 ring-2 ring-[#606060]/30 transition-transform group-hover:scale-105 group-hover:ring-[var(--brand-primary-hex)]/60 sm:h-16 sm:w-16'>
|
||||
<AvatarImage
|
||||
src={contributor.avatar_url}
|
||||
alt={contributor.login}
|
||||
@@ -405,13 +405,13 @@ export default function ContributorsPage() {
|
||||
</Avatar>
|
||||
|
||||
<div className='mt-2 text-center sm:mt-3'>
|
||||
<span className='block font-medium text-white text-xs transition-colors group-hover:text-[#701ffc] sm:text-sm'>
|
||||
<span className='block font-medium text-white text-xs transition-colors group-hover:text-[var(--brand-primary-hex)] sm:text-sm'>
|
||||
{contributor.login.length > 12
|
||||
? `${contributor.login.slice(0, 12)}...`
|
||||
: contributor.login}
|
||||
</span>
|
||||
<div className='mt-1 flex items-center justify-center gap-1 sm:mt-2'>
|
||||
<GitGraph className='h-2 w-2 text-neutral-400 transition-colors group-hover:text-[#701ffc] sm:h-3 sm:w-3' />
|
||||
<GitGraph className='h-2 w-2 text-neutral-400 transition-colors group-hover:text-[var(--brand-primary-hex)] sm:h-3 sm:w-3' />
|
||||
<span className='font-medium text-neutral-300 text-xs transition-colors group-hover:text-white sm:text-sm'>
|
||||
{contributor.contributions}
|
||||
</span>
|
||||
@@ -508,7 +508,7 @@ export default function ContributorsPage() {
|
||||
/>
|
||||
<Bar
|
||||
dataKey='contributions'
|
||||
className='fill-[#701ffc]'
|
||||
className='fill-[var(--brand-primary-hex)]'
|
||||
radius={[4, 4, 0, 0]}
|
||||
/>
|
||||
</BarChart>
|
||||
@@ -532,7 +532,7 @@ export default function ContributorsPage() {
|
||||
>
|
||||
<div className='relative p-6 sm:p-8 md:p-12 lg:p-16'>
|
||||
<div className='text-center'>
|
||||
<div className='mb-4 inline-flex items-center rounded-full border border-[#701ffc]/20 bg-[#701ffc]/10 px-3 py-1 font-medium text-[#701ffc] text-xs sm:mb-6 sm:px-4 sm:py-2 sm:text-sm'>
|
||||
<div className='mb-4 inline-flex items-center rounded-full border border-[var(--brand-primary-hex)]/20 bg-[var(--brand-primary-hex)]/10 px-3 py-1 font-medium text-[var(--brand-primary-hex)] text-xs sm:mb-6 sm:px-4 sm:py-2 sm:text-sm'>
|
||||
<Github className='mr-1 h-3 w-3 sm:mr-2 sm:h-4 sm:w-4' />
|
||||
Apache-2.0 Licensed
|
||||
</div>
|
||||
@@ -550,7 +550,7 @@ export default function ContributorsPage() {
|
||||
<Button
|
||||
asChild
|
||||
size='lg'
|
||||
className='bg-[#701ffc] text-white transition-colors duration-500 hover:bg-[#802FFF]'
|
||||
className='bg-[var(--brand-primary-hex)] text-white transition-colors duration-500 hover:bg-[var(--brand-primary-hover-hex)]'
|
||||
>
|
||||
<a
|
||||
href='https://github.com/simstudioai/sim/blob/main/.github/CONTRIBUTING.md'
|
||||
|
||||
@@ -12,7 +12,7 @@ export default function Landing() {
|
||||
}
|
||||
|
||||
return (
|
||||
<main className='relative min-h-screen bg-[#0C0C0C] font-geist-sans'>
|
||||
<main className='relative min-h-screen bg-[var(--brand-background-hex)] font-geist-sans'>
|
||||
<NavWrapper onOpenTypeformLink={handleOpenTypeformLink} />
|
||||
|
||||
<Hero />
|
||||
|
||||
@@ -11,7 +11,7 @@ export default function PrivacyPolicy() {
|
||||
}
|
||||
|
||||
return (
|
||||
<main className='relative min-h-screen overflow-hidden bg-[#0C0C0C] text-white'>
|
||||
<main className='relative min-h-screen overflow-hidden bg-[var(--brand-background-hex)] text-white'>
|
||||
{/* Grid pattern background - only covers content area */}
|
||||
<div className='absolute inset-0 bottom-[400px] z-0 overflow-hidden'>
|
||||
<GridPattern
|
||||
@@ -42,7 +42,7 @@ export default function PrivacyPolicy() {
|
||||
className='h-full w-full'
|
||||
>
|
||||
<g filter='url(#filter0_b_privacy)'>
|
||||
<rect width='600' height='1600' rx='0' fill='#0C0C0C' />
|
||||
<rect width='600' height='1600' rx='0' fill='var(--brand-background-hex)' />
|
||||
</g>
|
||||
<defs>
|
||||
<filter
|
||||
@@ -391,7 +391,7 @@ export default function PrivacyPolicy() {
|
||||
Privacy & Terms web page:{' '}
|
||||
<Link
|
||||
href='https://policies.google.com/privacy?hl=en'
|
||||
className='text-[#B5A1D4] hover:text-[#701ffc]'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
target='_blank'
|
||||
rel='noopener noreferrer'
|
||||
>
|
||||
@@ -569,7 +569,7 @@ export default function PrivacyPolicy() {
|
||||
Please note that we may ask you to verify your identity before responding to such
|
||||
requests.
|
||||
</p>
|
||||
<p className='mb-4 border-[#701ffc] border-l-4 bg-[#701ffc]/10 p-3'>
|
||||
<p className='mb-4 border-[var(--brand-primary-hex)] border-l-4 bg-[var(--brand-primary-hex)]/10 p-3'>
|
||||
You have the right to complain to a Data Protection Authority about our collection
|
||||
and use of your Personal Information. For more information, please contact your
|
||||
local data protection authority in the European Economic Area (EEA).
|
||||
@@ -661,7 +661,7 @@ export default function PrivacyPolicy() {
|
||||
policy (if any). Before beginning your inquiry, email us at{' '}
|
||||
<Link
|
||||
href='mailto:security@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[#701ffc]'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
security@sim.ai
|
||||
</Link>{' '}
|
||||
@@ -686,7 +686,7 @@ export default function PrivacyPolicy() {
|
||||
To report any security flaws, send an email to{' '}
|
||||
<Link
|
||||
href='mailto:security@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[#701ffc]'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
security@sim.ai
|
||||
</Link>
|
||||
@@ -726,7 +726,7 @@ export default function PrivacyPolicy() {
|
||||
If you have any questions about this Privacy Policy, please contact us at:{' '}
|
||||
<Link
|
||||
href='mailto:privacy@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[#701ffc]'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
privacy@sim.ai
|
||||
</Link>
|
||||
|
||||
@@ -11,7 +11,7 @@ export default function TermsOfService() {
|
||||
}
|
||||
|
||||
return (
|
||||
<main className='relative min-h-screen overflow-hidden bg-[#0C0C0C] text-white'>
|
||||
<main className='relative min-h-screen overflow-hidden bg-[var(--brand-background-hex)] text-white'>
|
||||
{/* Grid pattern background */}
|
||||
<div className='absolute inset-0 bottom-[400px] z-0 overflow-hidden'>
|
||||
<GridPattern
|
||||
@@ -42,7 +42,7 @@ export default function TermsOfService() {
|
||||
className='h-full w-full'
|
||||
>
|
||||
<g filter='url(#filter0_b_terms)'>
|
||||
<rect width='600' height='1600' rx='0' fill='#0C0C0C' />
|
||||
<rect width='600' height='1600' rx='0' fill='var(--brand-background-hex)' />
|
||||
</g>
|
||||
<defs>
|
||||
<filter
|
||||
@@ -268,7 +268,7 @@ export default function TermsOfService() {
|
||||
Arbitration Agreement. The arbitration will be conducted by JAMS, an established
|
||||
alternative dispute resolution provider.
|
||||
</p>
|
||||
<p className='mb-4 border-[#701ffc] border-l-4 bg-[#701ffc]/10 p-3'>
|
||||
<p className='mb-4 border-[var(--brand-primary-hex)] border-l-4 bg-[var(--brand-primary-hex)]/10 p-3'>
|
||||
YOU AND COMPANY AGREE THAT EACH OF US MAY BRING CLAIMS AGAINST THE OTHER ONLY ON
|
||||
AN INDIVIDUAL BASIS AND NOT ON A CLASS, REPRESENTATIVE, OR COLLECTIVE BASIS. ONLY
|
||||
INDIVIDUAL RELIEF IS AVAILABLE, AND DISPUTES OF MORE THAN ONE CUSTOMER OR USER
|
||||
@@ -277,7 +277,10 @@ export default function TermsOfService() {
|
||||
<p className='mb-4'>
|
||||
You have the right to opt out of the provisions of this Arbitration Agreement by
|
||||
sending a timely written notice of your decision to opt out to:{' '}
|
||||
<Link href='mailto:legal@sim.ai' className='text-[#B5A1D4] hover:text-[#701ffc]'>
|
||||
<Link
|
||||
href='mailto:legal@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
legal@sim.ai{' '}
|
||||
</Link>
|
||||
within 30 days after first becoming subject to this Arbitration Agreement.
|
||||
@@ -330,7 +333,7 @@ export default function TermsOfService() {
|
||||
Our Copyright Agent can be reached at:{' '}
|
||||
<Link
|
||||
href='mailto:copyright@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[#701ffc]'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
copyright@sim.ai
|
||||
</Link>
|
||||
@@ -341,7 +344,10 @@ export default function TermsOfService() {
|
||||
<h2 className='mb-4 font-semibold text-2xl text-white'>12. Contact Us</h2>
|
||||
<p>
|
||||
If you have any questions about these Terms, please contact us at:{' '}
|
||||
<Link href='mailto:legal@sim.ai' className='text-[#B5A1D4] hover:text-[#701ffc]'>
|
||||
<Link
|
||||
href='mailto:legal@sim.ai'
|
||||
className='text-[#B5A1D4] hover:text-[var(--brand-primary-hex)]'
|
||||
>
|
||||
legal@sim.ai
|
||||
</Link>
|
||||
</p>
|
||||
|
||||
@@ -354,6 +354,18 @@ export function mockExecutionDependencies() {
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock Trigger.dev SDK (tasks.trigger and task factory) for tests that import background modules
|
||||
*/
|
||||
export function mockTriggerDevSdk() {
|
||||
vi.mock('@trigger.dev/sdk', () => ({
|
||||
tasks: {
|
||||
trigger: vi.fn().mockResolvedValue({ id: 'mock-task-id' }),
|
||||
},
|
||||
task: vi.fn().mockReturnValue({}),
|
||||
}))
|
||||
}
|
||||
|
||||
export function mockWorkflowAccessValidation(shouldSucceed = true) {
|
||||
if (shouldSucceed) {
|
||||
vi.mock('@/app/api/workflows/middleware', () => ({
|
||||
|
||||
@@ -28,7 +28,12 @@ export async function POST(request: NextRequest) {
|
||||
return NextResponse.json({ error: 'Credential ID is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
const authz = await authorizeCredentialUse(request, { credentialId, workflowId })
|
||||
// We already have workflowId from the parsed body; avoid forcing hybrid auth to re-read it
|
||||
const authz = await authorizeCredentialUse(request, {
|
||||
credentialId,
|
||||
workflowId,
|
||||
requireWorkflowIdForInternal: false,
|
||||
})
|
||||
if (!authz.ok || !authz.credentialOwnerUserId) {
|
||||
return NextResponse.json({ error: authz.error || 'Unauthorized' }, { status: 403 })
|
||||
}
|
||||
@@ -79,14 +84,12 @@ export async function GET(request: NextRequest) {
|
||||
return NextResponse.json({ error: 'Credential not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
// Check if the access token is valid
|
||||
if (!credential.accessToken) {
|
||||
logger.warn(`[${requestId}] No access token available for credential`)
|
||||
return NextResponse.json({ error: 'No access token available' }, { status: 400 })
|
||||
}
|
||||
|
||||
try {
|
||||
// Refresh the token if needed
|
||||
const { accessToken } = await refreshTokenIfNeeded(requestId, credential, credentialId)
|
||||
return NextResponse.json({ accessToken }, { status: 200 })
|
||||
} catch (_error) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { and, desc, eq } from 'drizzle-orm'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { refreshOAuthToken } from '@/lib/oauth/oauth'
|
||||
@@ -70,7 +70,8 @@ export async function getOAuthToken(userId: string, providerId: string): Promise
|
||||
})
|
||||
.from(account)
|
||||
.where(and(eq(account.userId, userId), eq(account.providerId, providerId)))
|
||||
.orderBy(account.createdAt)
|
||||
// Always use the most recently updated credential for this provider
|
||||
.orderBy(desc(account.updatedAt))
|
||||
.limit(1)
|
||||
|
||||
if (connections.length === 0) {
|
||||
@@ -80,19 +81,13 @@ export async function getOAuthToken(userId: string, providerId: string): Promise
|
||||
|
||||
const credential = connections[0]
|
||||
|
||||
// Check if we have a valid access token
|
||||
if (!credential.accessToken) {
|
||||
logger.warn(`Access token is null for user ${userId}, provider ${providerId}`)
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the token is expired and needs refreshing
|
||||
// Determine whether we should refresh: missing token OR expired token
|
||||
const now = new Date()
|
||||
const tokenExpiry = credential.accessTokenExpiresAt
|
||||
// Only refresh if we have an expiration time AND it's expired AND we have a refresh token
|
||||
const needsRefresh = tokenExpiry && tokenExpiry < now && !!credential.refreshToken
|
||||
const shouldAttemptRefresh =
|
||||
!!credential.refreshToken && (!credential.accessToken || (tokenExpiry && tokenExpiry < now))
|
||||
|
||||
if (needsRefresh) {
|
||||
if (shouldAttemptRefresh) {
|
||||
logger.info(
|
||||
`Access token expired for user ${userId}, provider ${providerId}. Attempting to refresh.`
|
||||
)
|
||||
@@ -141,6 +136,13 @@ export async function getOAuthToken(userId: string, providerId: string): Promise
|
||||
}
|
||||
}
|
||||
|
||||
if (!credential.accessToken) {
|
||||
logger.warn(
|
||||
`Access token is null and no refresh attempted or available for user ${userId}, provider ${providerId}`
|
||||
)
|
||||
return null
|
||||
}
|
||||
|
||||
logger.info(`Found valid OAuth token for user ${userId}, provider ${providerId}`)
|
||||
return credential.accessToken
|
||||
}
|
||||
@@ -164,19 +166,21 @@ export async function refreshAccessTokenIfNeeded(
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if we need to refresh the token
|
||||
// Decide if we should refresh: token missing OR expired
|
||||
const expiresAt = credential.accessTokenExpiresAt
|
||||
const now = new Date()
|
||||
// Only refresh if we have an expiration time AND it's expired
|
||||
// If no expiration time is set (newly created credentials), assume token is valid
|
||||
const needsRefresh = expiresAt && expiresAt <= now
|
||||
const shouldRefresh =
|
||||
!!credential.refreshToken && (!credential.accessToken || (expiresAt && expiresAt <= now))
|
||||
|
||||
const accessToken = credential.accessToken
|
||||
|
||||
if (needsRefresh && credential.refreshToken) {
|
||||
if (shouldRefresh) {
|
||||
logger.info(`[${requestId}] Token expired, attempting to refresh for credential`)
|
||||
try {
|
||||
const refreshedToken = await refreshOAuthToken(credential.providerId, credential.refreshToken)
|
||||
const refreshedToken = await refreshOAuthToken(
|
||||
credential.providerId,
|
||||
credential.refreshToken!
|
||||
)
|
||||
|
||||
if (!refreshedToken) {
|
||||
logger.error(`[${requestId}] Failed to refresh token for credential: ${credentialId}`, {
|
||||
@@ -217,6 +221,7 @@ export async function refreshAccessTokenIfNeeded(
|
||||
return null
|
||||
}
|
||||
} else if (!accessToken) {
|
||||
// We have no access token and either no refresh token or not eligible to refresh
|
||||
logger.error(`[${requestId}] Missing access token for credential`)
|
||||
return null
|
||||
}
|
||||
@@ -233,21 +238,20 @@ export async function refreshTokenIfNeeded(
|
||||
credential: any,
|
||||
credentialId: string
|
||||
): Promise<{ accessToken: string; refreshed: boolean }> {
|
||||
// Check if we need to refresh the token
|
||||
// Decide if we should refresh: token missing OR expired
|
||||
const expiresAt = credential.accessTokenExpiresAt
|
||||
const now = new Date()
|
||||
// Only refresh if we have an expiration time AND it's expired
|
||||
// If no expiration time is set (newly created credentials), assume token is valid
|
||||
const needsRefresh = expiresAt && expiresAt <= now
|
||||
const shouldRefresh =
|
||||
!!credential.refreshToken && (!credential.accessToken || (expiresAt && expiresAt <= now))
|
||||
|
||||
// If token is still valid, return it directly
|
||||
if (!needsRefresh || !credential.refreshToken) {
|
||||
// If token appears valid and present, return it directly
|
||||
if (!shouldRefresh) {
|
||||
logger.info(`[${requestId}] Access token is valid`)
|
||||
return { accessToken: credential.accessToken, refreshed: false }
|
||||
}
|
||||
|
||||
try {
|
||||
const refreshResult = await refreshOAuthToken(credential.providerId, credential.refreshToken)
|
||||
const refreshResult = await refreshOAuthToken(credential.providerId, credential.refreshToken!)
|
||||
|
||||
if (!refreshResult) {
|
||||
logger.error(`[${requestId}] Failed to refresh token for credential`)
|
||||
|
||||
@@ -4,8 +4,9 @@ import { auth } from '@/lib/auth'
|
||||
|
||||
export async function POST() {
|
||||
try {
|
||||
const hdrs = await headers()
|
||||
const response = await auth.api.generateOneTimeToken({
|
||||
headers: await headers(),
|
||||
headers: hdrs,
|
||||
})
|
||||
|
||||
if (!response) {
|
||||
@@ -14,7 +15,6 @@ export async function POST() {
|
||||
|
||||
return NextResponse.json({ token: response.token })
|
||||
} catch (error) {
|
||||
console.error('Error generating one-time token:', error)
|
||||
return NextResponse.json({ error: 'Failed to generate token' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
7
apps/sim/app/api/auth/webhook/stripe/route.ts
Normal file
7
apps/sim/app/api/auth/webhook/stripe/route.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
import { toNextJsHandler } from 'better-auth/next-js'
|
||||
import { auth } from '@/lib/auth'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
// Handle Stripe webhooks through better-auth
|
||||
export const { GET, POST } = toNextJsHandler(auth.handler)
|
||||
@@ -1,109 +0,0 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { verifyCronAuth } from '@/lib/auth/internal'
|
||||
import { processDailyBillingCheck } from '@/lib/billing/core/billing'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('DailyBillingCron')
|
||||
|
||||
/**
|
||||
* Daily billing CRON job endpoint that checks individual billing periods
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const authError = verifyCronAuth(request, 'daily billing check')
|
||||
if (authError) {
|
||||
return authError
|
||||
}
|
||||
|
||||
logger.info('Starting daily billing check cron job')
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// Process overage billing for users and organizations with periods ending today
|
||||
const result = await processDailyBillingCheck()
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
if (result.success) {
|
||||
logger.info('Daily billing check completed successfully', {
|
||||
processedUsers: result.processedUsers,
|
||||
processedOrganizations: result.processedOrganizations,
|
||||
totalChargedAmount: result.totalChargedAmount,
|
||||
duration: `${duration}ms`,
|
||||
})
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
summary: {
|
||||
processedUsers: result.processedUsers,
|
||||
processedOrganizations: result.processedOrganizations,
|
||||
totalChargedAmount: result.totalChargedAmount,
|
||||
duration: `${duration}ms`,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
logger.error('Daily billing check completed with errors', {
|
||||
processedUsers: result.processedUsers,
|
||||
processedOrganizations: result.processedOrganizations,
|
||||
totalChargedAmount: result.totalChargedAmount,
|
||||
errorCount: result.errors.length,
|
||||
errors: result.errors,
|
||||
duration: `${duration}ms`,
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
summary: {
|
||||
processedUsers: result.processedUsers,
|
||||
processedOrganizations: result.processedOrganizations,
|
||||
totalChargedAmount: result.totalChargedAmount,
|
||||
errorCount: result.errors.length,
|
||||
duration: `${duration}ms`,
|
||||
},
|
||||
errors: result.errors,
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Fatal error in monthly billing cron job', { error })
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: 'Internal server error during daily billing check',
|
||||
details: error instanceof Error ? error.message : 'Unknown error',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* GET endpoint for manual testing and health checks
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
try {
|
||||
const authError = verifyCronAuth(request, 'daily billing check health check')
|
||||
if (authError) {
|
||||
return authError
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
status: 'ready',
|
||||
message:
|
||||
'Daily billing check cron job is ready to process users and organizations with periods ending today',
|
||||
currentDate: new Date().toISOString().split('T')[0],
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Error in billing health check', { error })
|
||||
return NextResponse.json(
|
||||
{
|
||||
status: 'error',
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
77
apps/sim/app/api/billing/portal/route.ts
Normal file
77
apps/sim/app/api/billing/portal/route.ts
Normal file
@@ -0,0 +1,77 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { requireStripeClient } from '@/lib/billing/stripe-client'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { subscription as subscriptionTable, user } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('BillingPortal')
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const session = await getSession()
|
||||
|
||||
try {
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const body = await request.json().catch(() => ({}))
|
||||
const context: 'user' | 'organization' =
|
||||
body?.context === 'organization' ? 'organization' : 'user'
|
||||
const organizationId: string | undefined = body?.organizationId || undefined
|
||||
const returnUrl: string =
|
||||
body?.returnUrl || `${env.NEXT_PUBLIC_APP_URL}/workspace?billing=updated`
|
||||
|
||||
const stripe = requireStripeClient()
|
||||
|
||||
let stripeCustomerId: string | null = null
|
||||
|
||||
if (context === 'organization') {
|
||||
if (!organizationId) {
|
||||
return NextResponse.json({ error: 'organizationId is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
const rows = await db
|
||||
.select({ customer: subscriptionTable.stripeCustomerId })
|
||||
.from(subscriptionTable)
|
||||
.where(
|
||||
and(
|
||||
eq(subscriptionTable.referenceId, organizationId),
|
||||
eq(subscriptionTable.status, 'active')
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
|
||||
} else {
|
||||
const rows = await db
|
||||
.select({ customer: user.stripeCustomerId })
|
||||
.from(user)
|
||||
.where(eq(user.id, session.user.id))
|
||||
.limit(1)
|
||||
|
||||
stripeCustomerId = rows.length > 0 ? rows[0].customer || null : null
|
||||
}
|
||||
|
||||
if (!stripeCustomerId) {
|
||||
logger.error('Stripe customer not found for portal session', {
|
||||
context,
|
||||
organizationId,
|
||||
userId: session.user.id,
|
||||
})
|
||||
return NextResponse.json({ error: 'Stripe customer not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const portal = await stripe.billingPortal.sessions.create({
|
||||
customer: stripeCustomerId,
|
||||
return_url: returnUrl,
|
||||
})
|
||||
|
||||
return NextResponse.json({ url: portal.url })
|
||||
} catch (error) {
|
||||
logger.error('Failed to create billing portal session', { error })
|
||||
return NextResponse.json({ error: 'Failed to create billing portal session' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,10 @@ import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getSimplifiedBillingSummary } from '@/lib/billing/core/billing'
|
||||
import { getOrganizationBillingData } from '@/lib/billing/core/organization-billing'
|
||||
import { getOrganizationBillingData } from '@/lib/billing/core/organization'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { member } from '@/db/schema'
|
||||
import { member, userStats } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('UnifiedBillingAPI')
|
||||
|
||||
@@ -45,6 +45,16 @@ export async function GET(request: NextRequest) {
|
||||
if (context === 'user') {
|
||||
// Get user billing (may include organization if they're part of one)
|
||||
billingData = await getSimplifiedBillingSummary(session.user.id, contextId || undefined)
|
||||
// Attach billingBlocked status for the current user
|
||||
const stats = await db
|
||||
.select({ blocked: userStats.billingBlocked })
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, session.user.id))
|
||||
.limit(1)
|
||||
billingData = {
|
||||
...billingData,
|
||||
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
|
||||
}
|
||||
} else {
|
||||
// Get user role in organization for permission checks first
|
||||
const memberRecord = await db
|
||||
@@ -78,8 +88,10 @@ export async function GET(request: NextRequest) {
|
||||
subscriptionStatus: rawBillingData.subscriptionStatus,
|
||||
totalSeats: rawBillingData.totalSeats,
|
||||
usedSeats: rawBillingData.usedSeats,
|
||||
seatsCount: rawBillingData.seatsCount,
|
||||
totalCurrentUsage: rawBillingData.totalCurrentUsage,
|
||||
totalUsageLimit: rawBillingData.totalUsageLimit,
|
||||
minimumBillingAmount: rawBillingData.minimumBillingAmount,
|
||||
averageUsagePerMember: rawBillingData.averageUsagePerMember,
|
||||
billingPeriodStart: rawBillingData.billingPeriodStart?.toISOString() || null,
|
||||
billingPeriodEnd: rawBillingData.billingPeriodEnd?.toISOString() || null,
|
||||
@@ -92,11 +104,25 @@ export async function GET(request: NextRequest) {
|
||||
|
||||
const userRole = memberRecord[0].role
|
||||
|
||||
// Include the requesting user's blocked flag as well so UI can reflect it
|
||||
const stats = await db
|
||||
.select({ blocked: userStats.billingBlocked })
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, session.user.id))
|
||||
.limit(1)
|
||||
|
||||
// Merge blocked flag into data for convenience
|
||||
billingData = {
|
||||
...billingData,
|
||||
billingBlocked: stats.length > 0 ? !!stats[0].blocked : false,
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
context,
|
||||
data: billingData,
|
||||
userRole,
|
||||
billingBlocked: billingData.billingBlocked,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@ import crypto from 'crypto'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { env } from '@/lib/env'
|
||||
import { isProd } from '@/lib/environment'
|
||||
import { checkInternalApiKey } from '@/lib/copilot/utils'
|
||||
import { isBillingEnabled } from '@/lib/environment'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { userStats } from '@/db/schema'
|
||||
@@ -11,34 +11,15 @@ import { calculateCost } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('billing-update-cost')
|
||||
|
||||
// Schema for the request body
|
||||
const UpdateCostSchema = z.object({
|
||||
userId: z.string().min(1, 'User ID is required'),
|
||||
input: z.number().min(0, 'Input tokens must be a non-negative number'),
|
||||
output: z.number().min(0, 'Output tokens must be a non-negative number'),
|
||||
model: z.string().min(1, 'Model is required'),
|
||||
inputMultiplier: z.number().min(0),
|
||||
outputMultiplier: z.number().min(0),
|
||||
})
|
||||
|
||||
// Authentication function (reused from copilot/methods route)
|
||||
function checkInternalApiKey(req: NextRequest) {
|
||||
const apiKey = req.headers.get('x-api-key')
|
||||
const expectedApiKey = env.INTERNAL_API_SECRET
|
||||
|
||||
if (!expectedApiKey) {
|
||||
return { success: false, error: 'Internal API key not configured' }
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
return { success: false, error: 'API key required' }
|
||||
}
|
||||
|
||||
if (apiKey !== expectedApiKey) {
|
||||
return { success: false, error: 'Invalid API key' }
|
||||
}
|
||||
|
||||
return { success: true }
|
||||
}
|
||||
|
||||
/**
|
||||
* POST /api/billing/update-cost
|
||||
* Update user cost based on token usage with internal API key auth
|
||||
@@ -50,6 +31,19 @@ export async function POST(req: NextRequest) {
|
||||
try {
|
||||
logger.info(`[${requestId}] Update cost request started`)
|
||||
|
||||
if (!isBillingEnabled) {
|
||||
logger.debug(`[${requestId}] Billing is disabled, skipping cost update`)
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
message: 'Billing disabled, cost update skipped',
|
||||
data: {
|
||||
billingEnabled: false,
|
||||
processedAt: new Date().toISOString(),
|
||||
requestId,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// Check authentication (internal API key)
|
||||
const authResult = checkInternalApiKey(req)
|
||||
if (!authResult.success) {
|
||||
@@ -82,27 +76,29 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
const { userId, input, output, model } = validation.data
|
||||
const { userId, input, output, model, inputMultiplier, outputMultiplier } = validation.data
|
||||
|
||||
logger.info(`[${requestId}] Processing cost update`, {
|
||||
userId,
|
||||
input,
|
||||
output,
|
||||
model,
|
||||
inputMultiplier,
|
||||
outputMultiplier,
|
||||
})
|
||||
|
||||
const finalPromptTokens = input
|
||||
const finalCompletionTokens = output
|
||||
const totalTokens = input + output
|
||||
|
||||
// Calculate cost using COPILOT_COST_MULTIPLIER (only in production, like normal executions)
|
||||
const copilotMultiplier = isProd ? env.COPILOT_COST_MULTIPLIER || 1 : 1
|
||||
// Calculate cost using provided multiplier (required)
|
||||
const costResult = calculateCost(
|
||||
model,
|
||||
finalPromptTokens,
|
||||
finalCompletionTokens,
|
||||
false,
|
||||
copilotMultiplier
|
||||
inputMultiplier,
|
||||
outputMultiplier
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Cost calculation result`, {
|
||||
@@ -111,7 +107,8 @@ export async function POST(req: NextRequest) {
|
||||
promptTokens: finalPromptTokens,
|
||||
completionTokens: finalCompletionTokens,
|
||||
totalTokens: totalTokens,
|
||||
copilotMultiplier,
|
||||
inputMultiplier,
|
||||
outputMultiplier,
|
||||
costResult,
|
||||
})
|
||||
|
||||
@@ -122,44 +119,34 @@ export async function POST(req: NextRequest) {
|
||||
const userStatsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))
|
||||
|
||||
if (userStatsRecords.length === 0) {
|
||||
// Create new user stats record (same logic as ExecutionLogger)
|
||||
await db.insert(userStats).values({
|
||||
id: crypto.randomUUID(),
|
||||
userId: userId,
|
||||
totalManualExecutions: 0,
|
||||
totalApiCalls: 0,
|
||||
totalWebhookTriggers: 0,
|
||||
totalScheduledExecutions: 0,
|
||||
totalChatExecutions: 0,
|
||||
totalTokensUsed: totalTokens,
|
||||
totalCost: costToStore.toString(),
|
||||
currentPeriodCost: costToStore.toString(),
|
||||
lastActive: new Date(),
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Created new user stats record`, {
|
||||
userId,
|
||||
totalCost: costToStore,
|
||||
totalTokens,
|
||||
})
|
||||
} else {
|
||||
// Update existing user stats record (same logic as ExecutionLogger)
|
||||
const updateFields = {
|
||||
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
|
||||
totalCost: sql`total_cost + ${costToStore}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
|
||||
totalApiCalls: sql`total_api_calls`,
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
|
||||
logger.info(`[${requestId}] Updated user stats record`, {
|
||||
userId,
|
||||
addedCost: costToStore,
|
||||
addedTokens: totalTokens,
|
||||
})
|
||||
logger.error(
|
||||
`[${requestId}] User stats record not found - should be created during onboarding`,
|
||||
{
|
||||
userId,
|
||||
}
|
||||
)
|
||||
return NextResponse.json({ error: 'User stats record not found' }, { status: 500 })
|
||||
}
|
||||
// Update existing user stats record (same logic as ExecutionLogger)
|
||||
const updateFields = {
|
||||
totalTokensUsed: sql`total_tokens_used + ${totalTokens}`,
|
||||
totalCost: sql`total_cost + ${costToStore}`,
|
||||
currentPeriodCost: sql`current_period_cost + ${costToStore}`,
|
||||
// Copilot usage tracking increments
|
||||
totalCopilotCost: sql`total_copilot_cost + ${costToStore}`,
|
||||
totalCopilotTokens: sql`total_copilot_tokens + ${totalTokens}`,
|
||||
totalCopilotCalls: sql`total_copilot_calls + 1`,
|
||||
totalApiCalls: sql`total_api_calls`,
|
||||
lastActive: new Date(),
|
||||
}
|
||||
|
||||
await db.update(userStats).set(updateFields).where(eq(userStats.userId, userId))
|
||||
|
||||
logger.info(`[${requestId}] Updated user stats record`, {
|
||||
userId,
|
||||
addedCost: costToStore,
|
||||
addedTokens: totalTokens,
|
||||
})
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
|
||||
@@ -1,116 +0,0 @@
|
||||
import { headers } from 'next/headers'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import type Stripe from 'stripe'
|
||||
import { requireStripeClient } from '@/lib/billing/stripe-client'
|
||||
import { handleInvoiceWebhook } from '@/lib/billing/webhooks/stripe-invoice-webhooks'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('StripeInvoiceWebhook')
|
||||
|
||||
/**
|
||||
* Stripe billing webhook endpoint for invoice-related events
|
||||
* Endpoint: /api/billing/webhooks/stripe
|
||||
* Handles: invoice.payment_succeeded, invoice.payment_failed, invoice.finalized
|
||||
*/
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const body = await request.text()
|
||||
const headersList = await headers()
|
||||
const signature = headersList.get('stripe-signature')
|
||||
|
||||
if (!signature) {
|
||||
logger.error('Missing Stripe signature header')
|
||||
return NextResponse.json({ error: 'Missing Stripe signature' }, { status: 400 })
|
||||
}
|
||||
|
||||
if (!env.STRIPE_BILLING_WEBHOOK_SECRET) {
|
||||
logger.error('Missing Stripe webhook secret configuration')
|
||||
return NextResponse.json({ error: 'Webhook secret not configured' }, { status: 500 })
|
||||
}
|
||||
|
||||
// Check if Stripe client is available
|
||||
let stripe
|
||||
try {
|
||||
stripe = requireStripeClient()
|
||||
} catch (stripeError) {
|
||||
logger.error('Stripe client not available for webhook processing', {
|
||||
error: stripeError,
|
||||
})
|
||||
return NextResponse.json({ error: 'Stripe client not configured' }, { status: 500 })
|
||||
}
|
||||
|
||||
// Verify webhook signature
|
||||
let event: Stripe.Event
|
||||
try {
|
||||
event = stripe.webhooks.constructEvent(body, signature, env.STRIPE_BILLING_WEBHOOK_SECRET)
|
||||
} catch (signatureError) {
|
||||
logger.error('Invalid Stripe webhook signature', {
|
||||
error: signatureError,
|
||||
signature,
|
||||
})
|
||||
return NextResponse.json({ error: 'Invalid signature' }, { status: 400 })
|
||||
}
|
||||
|
||||
logger.info('Received Stripe invoice webhook', {
|
||||
eventId: event.id,
|
||||
eventType: event.type,
|
||||
})
|
||||
|
||||
// Handle specific invoice events
|
||||
const supportedEvents = [
|
||||
'invoice.payment_succeeded',
|
||||
'invoice.payment_failed',
|
||||
'invoice.finalized',
|
||||
]
|
||||
|
||||
if (supportedEvents.includes(event.type)) {
|
||||
try {
|
||||
await handleInvoiceWebhook(event)
|
||||
|
||||
logger.info('Successfully processed invoice webhook', {
|
||||
eventId: event.id,
|
||||
eventType: event.type,
|
||||
})
|
||||
|
||||
return NextResponse.json({ received: true })
|
||||
} catch (processingError) {
|
||||
logger.error('Failed to process invoice webhook', {
|
||||
eventId: event.id,
|
||||
eventType: event.type,
|
||||
error: processingError,
|
||||
})
|
||||
|
||||
// Return 500 to tell Stripe to retry the webhook
|
||||
return NextResponse.json({ error: 'Failed to process webhook' }, { status: 500 })
|
||||
}
|
||||
} else {
|
||||
// Not a supported invoice event, ignore
|
||||
logger.info('Ignoring unsupported webhook event', {
|
||||
eventId: event.id,
|
||||
eventType: event.type,
|
||||
supportedEvents,
|
||||
})
|
||||
|
||||
return NextResponse.json({ received: true })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Fatal error in invoice webhook handler', {
|
||||
error,
|
||||
url: request.url,
|
||||
})
|
||||
|
||||
return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* GET endpoint for webhook health checks
|
||||
*/
|
||||
export async function GET() {
|
||||
return NextResponse.json({
|
||||
status: 'healthy',
|
||||
webhook: 'stripe-invoices',
|
||||
events: ['invoice.payment_succeeded', 'invoice.payment_failed', 'invoice.finalized'],
|
||||
})
|
||||
}
|
||||
@@ -246,7 +246,10 @@ describe('Chat API Route', () => {
|
||||
NEXT_PUBLIC_APP_URL: 'http://localhost:3000',
|
||||
},
|
||||
isTruthy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
|
||||
typeof value === 'string'
|
||||
? value.toLowerCase() === 'true' || value === '1'
|
||||
: Boolean(value),
|
||||
getEnv: (variable: string) => process.env[variable],
|
||||
}))
|
||||
|
||||
const validData = {
|
||||
@@ -291,6 +294,7 @@ describe('Chat API Route', () => {
|
||||
},
|
||||
isTruthy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
|
||||
getEnv: (variable: string) => process.env[variable],
|
||||
}))
|
||||
|
||||
const validData = {
|
||||
|
||||
@@ -14,8 +14,6 @@ import { chat } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('ChatAPI')
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
const chatSchema = z.object({
|
||||
workflowId: z.string().min(1, 'Workflow ID is required'),
|
||||
subdomain: z
|
||||
@@ -150,7 +148,7 @@ export async function POST(request: NextRequest) {
|
||||
// Merge customizations with the additional fields
|
||||
const mergedCustomizations = {
|
||||
...(customizations || {}),
|
||||
primaryColor: customizations?.primaryColor || '#802FFF',
|
||||
primaryColor: customizations?.primaryColor || 'var(--brand-primary-hover-hex)',
|
||||
welcomeMessage: customizations?.welcomeMessage || 'Hi there! How can I help you today?',
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,6 @@ import { eq } from 'drizzle-orm'
|
||||
import { NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import { createErrorResponse, createSuccessResponse } from '@/app/api/workflows/utils'
|
||||
import { db } from '@/db'
|
||||
import { chat } from '@/db/schema'
|
||||
@@ -48,6 +45,7 @@ export async function GET(request: Request) {
|
||||
'support',
|
||||
'admin',
|
||||
'qa',
|
||||
'agent',
|
||||
]
|
||||
if (reservedSubdomains.includes(subdomain)) {
|
||||
return NextResponse.json(
|
||||
|
||||
@@ -3,6 +3,7 @@ import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { checkServerSideUsageLimits } from '@/lib/billing'
|
||||
import { isDev } from '@/lib/environment'
|
||||
import { getPersonalAndWorkspaceEnv } from '@/lib/environment/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { LoggingSession } from '@/lib/logs/execution/logging-session'
|
||||
import { buildTraceSpans } from '@/lib/logs/execution/trace-spans/trace-spans'
|
||||
@@ -12,7 +13,7 @@ import { getEmailDomain } from '@/lib/urls/utils'
|
||||
import { decryptSecret } from '@/lib/utils'
|
||||
import { getBlock } from '@/blocks'
|
||||
import { db } from '@/db'
|
||||
import { chat, environment as envTable, userStats, workflow } from '@/db/schema'
|
||||
import { chat, userStats, workflow } from '@/db/schema'
|
||||
import { Executor } from '@/executor'
|
||||
import type { BlockLog, ExecutionResult } from '@/executor/types'
|
||||
import { Serializer } from '@/serializer'
|
||||
@@ -453,18 +454,21 @@ export async function executeWorkflowForChat(
|
||||
{} as Record<string, Record<string, any>>
|
||||
)
|
||||
|
||||
// Get user environment variables for this workflow
|
||||
// Get user environment variables with workspace precedence
|
||||
let envVars: Record<string, string> = {}
|
||||
try {
|
||||
const envResult = await db
|
||||
.select()
|
||||
.from(envTable)
|
||||
.where(eq(envTable.userId, deployment.userId))
|
||||
const wfWorkspaceRow = await db
|
||||
.select({ workspaceId: workflow.workspaceId })
|
||||
.from(workflow)
|
||||
.where(eq(workflow.id, workflowId))
|
||||
.limit(1)
|
||||
|
||||
if (envResult.length > 0 && envResult[0].variables) {
|
||||
envVars = envResult[0].variables as Record<string, string>
|
||||
}
|
||||
const workspaceId = wfWorkspaceRow[0]?.workspaceId || undefined
|
||||
const { personalEncrypted, workspaceEncrypted } = await getPersonalAndWorkspaceEnv(
|
||||
deployment.userId,
|
||||
workspaceId
|
||||
)
|
||||
envVars = { ...personalEncrypted, ...workspaceEncrypted }
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Could not fetch environment variables:`, error)
|
||||
}
|
||||
|
||||
53
apps/sim/app/api/copilot/api-keys/generate/route.ts
Normal file
53
apps/sim/app/api/copilot/api-keys/generate/route.ts
Normal file
@@ -0,0 +1,53 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
|
||||
const logger = createLogger('CopilotApiKeysGenerate')
|
||||
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
|
||||
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/generate`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify({ userId }),
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const errorBody = await res.text().catch(() => '')
|
||||
logger.error('Sim Agent generate key error', { status: res.status, error: errorBody })
|
||||
return NextResponse.json(
|
||||
{ error: 'Failed to generate copilot API key' },
|
||||
{ status: res.status || 500 }
|
||||
)
|
||||
}
|
||||
|
||||
const data = (await res.json().catch(() => null)) as { apiKey?: string } | null
|
||||
|
||||
if (!data?.apiKey) {
|
||||
logger.error('Sim Agent generate key returned invalid payload')
|
||||
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ success: true, key: { id: 'new', apiKey: data.apiKey } },
|
||||
{ status: 201 }
|
||||
)
|
||||
} catch (error) {
|
||||
logger.error('Failed to proxy generate copilot API key', { error })
|
||||
return NextResponse.json({ error: 'Failed to generate copilot API key' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
91
apps/sim/app/api/copilot/api-keys/route.ts
Normal file
91
apps/sim/app/api/copilot/api-keys/route.ts
Normal file
@@ -0,0 +1,91 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
|
||||
const logger = createLogger('CopilotApiKeys')
|
||||
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
export async function GET(request: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
|
||||
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/get-api-keys`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify({ userId }),
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const errorBody = await res.text().catch(() => '')
|
||||
logger.error('Sim Agent get-api-keys error', { status: res.status, error: errorBody })
|
||||
return NextResponse.json({ error: 'Failed to get keys' }, { status: res.status || 500 })
|
||||
}
|
||||
|
||||
const apiKeys = (await res.json().catch(() => null)) as { id: string; apiKey: string }[] | null
|
||||
|
||||
if (!Array.isArray(apiKeys)) {
|
||||
logger.error('Sim Agent get-api-keys returned invalid payload')
|
||||
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
|
||||
}
|
||||
|
||||
const keys = apiKeys
|
||||
|
||||
return NextResponse.json({ keys }, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error('Failed to get copilot API keys', { error })
|
||||
return NextResponse.json({ error: 'Failed to get keys' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
|
||||
export async function DELETE(request: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
const url = new URL(request.url)
|
||||
const id = url.searchParams.get('id')
|
||||
if (!id) {
|
||||
return NextResponse.json({ error: 'id is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
const res = await fetch(`${SIM_AGENT_API_URL}/api/validate-key/delete`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify({ userId, apiKeyId: id }),
|
||||
})
|
||||
|
||||
if (!res.ok) {
|
||||
const errorBody = await res.text().catch(() => '')
|
||||
logger.error('Sim Agent delete key error', { status: res.status, error: errorBody })
|
||||
return NextResponse.json({ error: 'Failed to delete key' }, { status: res.status || 500 })
|
||||
}
|
||||
|
||||
const data = (await res.json().catch(() => null)) as { success?: boolean } | null
|
||||
if (!data?.success) {
|
||||
logger.error('Sim Agent delete key returned invalid payload')
|
||||
return NextResponse.json({ error: 'Invalid response from Sim Agent' }, { status: 500 })
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true }, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error('Failed to delete copilot API key', { error })
|
||||
return NextResponse.json({ error: 'Failed to delete key' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
58
apps/sim/app/api/copilot/api-keys/validate/route.ts
Normal file
58
apps/sim/app/api/copilot/api-keys/validate/route.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { checkInternalApiKey } from '@/lib/copilot/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { userStats } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('CopilotApiKeysValidate')
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
try {
|
||||
// Authenticate via internal API key header
|
||||
const auth = checkInternalApiKey(req)
|
||||
if (!auth.success) {
|
||||
return new NextResponse(null, { status: 401 })
|
||||
}
|
||||
|
||||
const body = await req.json().catch(() => null)
|
||||
const userId = typeof body?.userId === 'string' ? body.userId : undefined
|
||||
|
||||
if (!userId) {
|
||||
return NextResponse.json({ error: 'userId is required' }, { status: 400 })
|
||||
}
|
||||
|
||||
logger.info('[API VALIDATION] Validating usage limit', { userId })
|
||||
|
||||
const usage = await db
|
||||
.select({
|
||||
currentPeriodCost: userStats.currentPeriodCost,
|
||||
totalCost: userStats.totalCost,
|
||||
currentUsageLimit: userStats.currentUsageLimit,
|
||||
})
|
||||
.from(userStats)
|
||||
.where(eq(userStats.userId, userId))
|
||||
.limit(1)
|
||||
|
||||
logger.info('[API VALIDATION] Usage limit validated', { userId, usage })
|
||||
|
||||
if (usage.length > 0) {
|
||||
const currentUsage = Number.parseFloat(
|
||||
(usage[0].currentPeriodCost?.toString() as string) ||
|
||||
(usage[0].totalCost as unknown as string) ||
|
||||
'0'
|
||||
)
|
||||
const limit = Number.parseFloat((usage[0].currentUsageLimit as unknown as string) || '0')
|
||||
|
||||
if (!Number.isNaN(limit) && limit > 0 && currentUsage >= limit) {
|
||||
logger.info('[API VALIDATION] Usage exceeded', { userId, currentUsage, limit })
|
||||
return new NextResponse(null, { status: 402 })
|
||||
}
|
||||
}
|
||||
|
||||
return new NextResponse(null, { status: 200 })
|
||||
} catch (error) {
|
||||
logger.error('Error validating usage limit', { error })
|
||||
return NextResponse.json({ error: 'Failed to validate usage' }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -104,7 +104,8 @@ describe('Copilot Chat API Route', () => {
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
SIM_AGENT_API_URL: 'http://localhost:8000',
|
||||
SIM_AGENT_API_KEY: 'test-sim-agent-key',
|
||||
COPILOT_API_KEY: 'test-sim-agent-key',
|
||||
BETTER_AUTH_URL: 'http://localhost:3000',
|
||||
},
|
||||
}))
|
||||
|
||||
@@ -223,6 +224,9 @@ describe('Copilot Chat API Route', () => {
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
depth: 0,
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
@@ -284,6 +288,9 @@ describe('Copilot Chat API Route', () => {
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
depth: 0,
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
@@ -293,7 +300,6 @@ describe('Copilot Chat API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
const newChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
@@ -302,8 +308,6 @@ describe('Copilot Chat API Route', () => {
|
||||
}
|
||||
mockReturning.mockResolvedValue([newChat])
|
||||
|
||||
// Mock sim agent response
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
@@ -337,6 +341,9 @@ describe('Copilot Chat API Route', () => {
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'agent',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
depth: 0,
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
@@ -346,11 +353,8 @@ describe('Copilot Chat API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
// Mock sim agent error
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
@@ -396,11 +400,8 @@ describe('Copilot Chat API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock new chat creation
|
||||
mockReturning.mockResolvedValue([{ id: 'chat-123', messages: [] }])
|
||||
|
||||
// Mock sim agent response
|
||||
|
||||
;(global.fetch as any).mockResolvedValue({
|
||||
ok: true,
|
||||
body: new ReadableStream({
|
||||
@@ -430,6 +431,9 @@ describe('Copilot Chat API Route', () => {
|
||||
stream: true,
|
||||
streamToolCalls: true,
|
||||
mode: 'ask',
|
||||
messageId: 'mock-uuid-1234-5678',
|
||||
depth: 0,
|
||||
chatId: 'chat-123',
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
@@ -10,127 +10,69 @@ import {
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { getCopilotModel } from '@/lib/copilot/config'
|
||||
import { TITLE_GENERATION_SYSTEM_PROMPT, TITLE_GENERATION_USER_PROMPT } from '@/lib/copilot/prompts'
|
||||
import type { CopilotProviderConfig } from '@/lib/copilot/types'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { downloadFile } from '@/lib/uploads'
|
||||
import { downloadFromS3WithConfig } from '@/lib/uploads/s3/s3-client'
|
||||
import { S3_COPILOT_CONFIG, USE_S3_STORAGE } from '@/lib/uploads/setup'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
import { generateChatTitle } from '@/lib/sim-agent/utils'
|
||||
import { createFileContent, isSupportedFileType } from '@/lib/uploads/file-utils'
|
||||
import { S3_COPILOT_CONFIG } from '@/lib/uploads/setup'
|
||||
import { downloadFile, getStorageProvider } from '@/lib/uploads/storage-client'
|
||||
import { db } from '@/db'
|
||||
import { copilotChats } from '@/db/schema'
|
||||
import { executeProviderRequest } from '@/providers'
|
||||
import { createAnthropicFileContent, isSupportedFileType } from './file-utils'
|
||||
|
||||
const logger = createLogger('CopilotChatAPI')
|
||||
|
||||
// Schema for file attachments
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
const FileAttachmentSchema = z.object({
|
||||
id: z.string(),
|
||||
s3_key: z.string(),
|
||||
key: z.string(),
|
||||
filename: z.string(),
|
||||
media_type: z.string(),
|
||||
size: z.number(),
|
||||
})
|
||||
|
||||
// Schema for chat messages
|
||||
const ChatMessageSchema = z.object({
|
||||
message: z.string().min(1, 'Message is required'),
|
||||
userMessageId: z.string().optional(), // ID from frontend for the user message
|
||||
chatId: z.string().optional(),
|
||||
workflowId: z.string().min(1, 'Workflow ID is required'),
|
||||
mode: z.enum(['ask', 'agent']).optional().default('agent'),
|
||||
depth: z.number().int().min(0).max(3).optional().default(0),
|
||||
prefetch: z.boolean().optional(),
|
||||
createNewChat: z.boolean().optional().default(false),
|
||||
stream: z.boolean().optional().default(true),
|
||||
implicitFeedback: z.string().optional(),
|
||||
fileAttachments: z.array(FileAttachmentSchema).optional(),
|
||||
})
|
||||
|
||||
// Sim Agent API configuration
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || 'http://localhost:8000'
|
||||
const SIM_AGENT_API_KEY = env.SIM_AGENT_API_KEY
|
||||
|
||||
/**
|
||||
* Generate a chat title using LLM
|
||||
*/
|
||||
async function generateChatTitle(userMessage: string): Promise<string> {
|
||||
try {
|
||||
const { provider, model } = getCopilotModel('title')
|
||||
|
||||
// Get the appropriate API key for the provider
|
||||
let apiKey: string | undefined
|
||||
if (provider === 'anthropic') {
|
||||
// Use rotating API key for Anthropic
|
||||
const { getRotatingApiKey } = require('@/lib/utils')
|
||||
try {
|
||||
apiKey = getRotatingApiKey('anthropic')
|
||||
logger.debug(`Using rotating API key for Anthropic title generation`)
|
||||
} catch (e) {
|
||||
// If rotation fails, let the provider handle it
|
||||
logger.warn(`Failed to get rotating API key for Anthropic:`, e)
|
||||
}
|
||||
}
|
||||
|
||||
const response = await executeProviderRequest(provider, {
|
||||
model,
|
||||
systemPrompt: TITLE_GENERATION_SYSTEM_PROMPT,
|
||||
context: TITLE_GENERATION_USER_PROMPT(userMessage),
|
||||
temperature: 0.3,
|
||||
maxTokens: 50,
|
||||
apiKey: apiKey || '',
|
||||
stream: false,
|
||||
})
|
||||
|
||||
if (typeof response === 'object' && 'content' in response) {
|
||||
return response.content?.trim() || 'New Chat'
|
||||
}
|
||||
|
||||
return 'New Chat'
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate chat title:', error)
|
||||
return 'New Chat'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate chat title asynchronously and update the database
|
||||
*/
|
||||
async function generateChatTitleAsync(
|
||||
chatId: string,
|
||||
userMessage: string,
|
||||
requestId: string,
|
||||
streamController?: ReadableStreamDefaultController<Uint8Array>
|
||||
): Promise<void> {
|
||||
try {
|
||||
logger.info(`[${requestId}] Starting async title generation for chat ${chatId}`)
|
||||
|
||||
const title = await generateChatTitle(userMessage)
|
||||
|
||||
// Update the chat with the generated title
|
||||
await db
|
||||
.update(copilotChats)
|
||||
.set({
|
||||
title,
|
||||
updatedAt: new Date(),
|
||||
provider: z.string().optional().default('openai'),
|
||||
conversationId: z.string().optional(),
|
||||
contexts: z
|
||||
.array(
|
||||
z.object({
|
||||
kind: z.enum([
|
||||
'past_chat',
|
||||
'workflow',
|
||||
'current_workflow',
|
||||
'blocks',
|
||||
'logs',
|
||||
'workflow_block',
|
||||
'knowledge',
|
||||
'templates',
|
||||
'docs',
|
||||
]),
|
||||
label: z.string(),
|
||||
chatId: z.string().optional(),
|
||||
workflowId: z.string().optional(),
|
||||
knowledgeId: z.string().optional(),
|
||||
blockId: z.string().optional(),
|
||||
templateId: z.string().optional(),
|
||||
executionId: z.string().optional(),
|
||||
// For workflow_block, provide both workflowId and blockId
|
||||
})
|
||||
.where(eq(copilotChats.id, chatId))
|
||||
|
||||
// Send title_updated event to client if streaming
|
||||
if (streamController) {
|
||||
const encoder = new TextEncoder()
|
||||
const titleEvent = `data: ${JSON.stringify({
|
||||
type: 'title_updated',
|
||||
title: title,
|
||||
})}\n\n`
|
||||
streamController.enqueue(encoder.encode(titleEvent))
|
||||
logger.debug(`[${requestId}] Sent title_updated event to client: "${title}"`)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Generated title for chat ${chatId}: "${title}"`)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Failed to generate title for chat ${chatId}:`, error)
|
||||
// Don't throw - this is a background operation
|
||||
}
|
||||
}
|
||||
)
|
||||
.optional(),
|
||||
})
|
||||
|
||||
/**
|
||||
* POST /api/copilot/chat
|
||||
@@ -156,22 +98,67 @@ export async function POST(req: NextRequest) {
|
||||
chatId,
|
||||
workflowId,
|
||||
mode,
|
||||
depth,
|
||||
prefetch,
|
||||
createNewChat,
|
||||
stream,
|
||||
implicitFeedback,
|
||||
fileAttachments,
|
||||
provider,
|
||||
conversationId,
|
||||
contexts,
|
||||
} = ChatMessageSchema.parse(body)
|
||||
// Ensure we have a consistent user message ID for this request
|
||||
const userMessageIdToUse = userMessageId || crypto.randomUUID()
|
||||
try {
|
||||
logger.info(`[${tracker.requestId}] Received chat POST`, {
|
||||
hasContexts: Array.isArray(contexts),
|
||||
contextsCount: Array.isArray(contexts) ? contexts.length : 0,
|
||||
contextsPreview: Array.isArray(contexts)
|
||||
? contexts.map((c: any) => ({
|
||||
kind: c?.kind,
|
||||
chatId: c?.chatId,
|
||||
workflowId: c?.workflowId,
|
||||
executionId: (c as any)?.executionId,
|
||||
label: c?.label,
|
||||
}))
|
||||
: undefined,
|
||||
})
|
||||
} catch {}
|
||||
// Preprocess contexts server-side
|
||||
let agentContexts: Array<{ type: string; content: string }> = []
|
||||
if (Array.isArray(contexts) && contexts.length > 0) {
|
||||
try {
|
||||
const { processContextsServer } = await import('@/lib/copilot/process-contents')
|
||||
const processed = await processContextsServer(contexts as any, authenticatedUserId, message)
|
||||
agentContexts = processed
|
||||
logger.info(`[${tracker.requestId}] Contexts processed for request`, {
|
||||
processedCount: agentContexts.length,
|
||||
kinds: agentContexts.map((c) => c.type),
|
||||
lengthPreview: agentContexts.map((c) => c.content?.length ?? 0),
|
||||
})
|
||||
if (Array.isArray(contexts) && contexts.length > 0 && agentContexts.length === 0) {
|
||||
logger.warn(
|
||||
`[${tracker.requestId}] Contexts provided but none processed. Check executionId for logs contexts.`
|
||||
)
|
||||
}
|
||||
} catch (e) {
|
||||
logger.error(`[${tracker.requestId}] Failed to process contexts`, e)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`[${tracker.requestId}] Processing copilot chat request`, {
|
||||
userId: authenticatedUserId,
|
||||
workflowId,
|
||||
chatId,
|
||||
mode,
|
||||
stream,
|
||||
createNewChat,
|
||||
messageLength: message.length,
|
||||
hasImplicitFeedback: !!implicitFeedback,
|
||||
})
|
||||
// Consolidation mapping: map negative depths to base depth with prefetch=true
|
||||
let effectiveDepth: number | undefined = typeof depth === 'number' ? depth : undefined
|
||||
let effectivePrefetch: boolean | undefined = prefetch
|
||||
if (typeof effectiveDepth === 'number') {
|
||||
if (effectiveDepth === -2) {
|
||||
effectiveDepth = 1
|
||||
effectivePrefetch = true
|
||||
} else if (effectiveDepth === -1) {
|
||||
effectiveDepth = 0
|
||||
effectivePrefetch = true
|
||||
}
|
||||
}
|
||||
|
||||
// Handle chat context
|
||||
let currentChat: any = null
|
||||
@@ -213,8 +200,6 @@ export async function POST(req: NextRequest) {
|
||||
// Process file attachments if present
|
||||
const processedFileContents: any[] = []
|
||||
if (fileAttachments && fileAttachments.length > 0) {
|
||||
logger.info(`[${tracker.requestId}] Processing ${fileAttachments.length} file attachments`)
|
||||
|
||||
for (const attachment of fileAttachments) {
|
||||
try {
|
||||
// Check if file type is supported
|
||||
@@ -223,23 +208,30 @@ export async function POST(req: NextRequest) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Download file from S3
|
||||
logger.info(`[${tracker.requestId}] Downloading file: ${attachment.s3_key}`)
|
||||
const storageProvider = getStorageProvider()
|
||||
let fileBuffer: Buffer
|
||||
if (USE_S3_STORAGE) {
|
||||
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
|
||||
|
||||
if (storageProvider === 's3') {
|
||||
fileBuffer = await downloadFile(attachment.key, {
|
||||
bucket: S3_COPILOT_CONFIG.bucket,
|
||||
region: S3_COPILOT_CONFIG.region,
|
||||
})
|
||||
} else if (storageProvider === 'blob') {
|
||||
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
|
||||
fileBuffer = await downloadFile(attachment.key, {
|
||||
containerName: BLOB_COPILOT_CONFIG.containerName,
|
||||
accountName: BLOB_COPILOT_CONFIG.accountName,
|
||||
accountKey: BLOB_COPILOT_CONFIG.accountKey,
|
||||
connectionString: BLOB_COPILOT_CONFIG.connectionString,
|
||||
})
|
||||
} else {
|
||||
// Fallback to generic downloadFile for other storage providers
|
||||
fileBuffer = await downloadFile(attachment.s3_key)
|
||||
fileBuffer = await downloadFile(attachment.key)
|
||||
}
|
||||
|
||||
// Convert to Anthropic format
|
||||
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
|
||||
// Convert to format
|
||||
const fileContent = createFileContent(fileBuffer, attachment.media_type)
|
||||
if (fileContent) {
|
||||
processedFileContents.push(fileContent)
|
||||
logger.info(
|
||||
`[${tracker.requestId}] Processed file: ${attachment.filename} (${attachment.media_type})`
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error(
|
||||
@@ -252,7 +244,7 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
|
||||
// Build messages array for sim agent with conversation history
|
||||
const messages = []
|
||||
const messages: any[] = []
|
||||
|
||||
// Add conversation history (need to rebuild these with file support if they had attachments)
|
||||
for (const msg of conversationHistory) {
|
||||
@@ -264,14 +256,26 @@ export async function POST(req: NextRequest) {
|
||||
for (const attachment of msg.fileAttachments) {
|
||||
try {
|
||||
if (isSupportedFileType(attachment.media_type)) {
|
||||
const storageProvider = getStorageProvider()
|
||||
let fileBuffer: Buffer
|
||||
if (USE_S3_STORAGE) {
|
||||
fileBuffer = await downloadFromS3WithConfig(attachment.s3_key, S3_COPILOT_CONFIG)
|
||||
|
||||
if (storageProvider === 's3') {
|
||||
fileBuffer = await downloadFile(attachment.key, {
|
||||
bucket: S3_COPILOT_CONFIG.bucket,
|
||||
region: S3_COPILOT_CONFIG.region,
|
||||
})
|
||||
} else if (storageProvider === 'blob') {
|
||||
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
|
||||
fileBuffer = await downloadFile(attachment.key, {
|
||||
containerName: BLOB_COPILOT_CONFIG.containerName,
|
||||
accountName: BLOB_COPILOT_CONFIG.accountName,
|
||||
accountKey: BLOB_COPILOT_CONFIG.accountKey,
|
||||
connectionString: BLOB_COPILOT_CONFIG.connectionString,
|
||||
})
|
||||
} else {
|
||||
// Fallback to generic downloadFile for other storage providers
|
||||
fileBuffer = await downloadFile(attachment.s3_key)
|
||||
fileBuffer = await downloadFile(attachment.key)
|
||||
}
|
||||
const fileContent = createAnthropicFileContent(fileBuffer, attachment.media_type)
|
||||
const fileContent = createFileContent(fileBuffer, attachment.media_type)
|
||||
if (fileContent) {
|
||||
content.push(fileContent)
|
||||
}
|
||||
@@ -327,40 +331,83 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
}
|
||||
|
||||
// Start title generation in parallel if this is a new chat with first message
|
||||
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${tracker.requestId}] Will start parallel title generation inside stream`)
|
||||
const defaults = getCopilotModel('chat')
|
||||
const modelToUse = env.COPILOT_MODEL || defaults.model
|
||||
|
||||
let providerConfig: CopilotProviderConfig | undefined
|
||||
const providerEnv = env.COPILOT_PROVIDER as any
|
||||
|
||||
if (providerEnv) {
|
||||
if (providerEnv === 'azure-openai') {
|
||||
providerConfig = {
|
||||
provider: 'azure-openai',
|
||||
model: modelToUse,
|
||||
apiKey: env.AZURE_OPENAI_API_KEY,
|
||||
apiVersion: 'preview',
|
||||
endpoint: env.AZURE_OPENAI_ENDPOINT,
|
||||
}
|
||||
} else {
|
||||
providerConfig = {
|
||||
provider: providerEnv,
|
||||
model: modelToUse,
|
||||
apiKey: env.COPILOT_API_KEY,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Forward to sim agent API
|
||||
logger.info(`[${tracker.requestId}] Sending request to sim agent API`, {
|
||||
messageCount: messages.length,
|
||||
endpoint: `${SIM_AGENT_API_URL}/api/chat-completion-streaming`,
|
||||
})
|
||||
// Determine provider and conversationId to use for this request
|
||||
const effectiveConversationId =
|
||||
(currentChat?.conversationId as string | undefined) || conversationId
|
||||
|
||||
// If we have a conversationId, only send the most recent user message; else send full history
|
||||
const latestUserMessage =
|
||||
[...messages].reverse().find((m) => m?.role === 'user') || messages[messages.length - 1]
|
||||
const messagesForAgent = effectiveConversationId ? [latestUserMessage] : messages
|
||||
|
||||
const requestPayload = {
|
||||
messages: messagesForAgent,
|
||||
workflowId,
|
||||
userId: authenticatedUserId,
|
||||
stream: stream,
|
||||
streamToolCalls: true,
|
||||
mode: mode,
|
||||
messageId: userMessageIdToUse,
|
||||
...(providerConfig ? { provider: providerConfig } : {}),
|
||||
...(effectiveConversationId ? { conversationId: effectiveConversationId } : {}),
|
||||
...(typeof effectiveDepth === 'number' ? { depth: effectiveDepth } : {}),
|
||||
...(typeof effectivePrefetch === 'boolean' ? { prefetch: effectivePrefetch } : {}),
|
||||
...(session?.user?.name && { userName: session.user.name }),
|
||||
...(agentContexts.length > 0 && { context: agentContexts }),
|
||||
...(actualChatId ? { chatId: actualChatId } : {}),
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info(`[${tracker.requestId}] About to call Sim Agent with context`, {
|
||||
context: (requestPayload as any).context,
|
||||
})
|
||||
} catch {}
|
||||
|
||||
const simAgentResponse = await fetch(`${SIM_AGENT_API_URL}/api/chat-completion-streaming`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(SIM_AGENT_API_KEY && { 'x-api-key': SIM_AGENT_API_KEY }),
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
messages,
|
||||
workflowId,
|
||||
userId: authenticatedUserId,
|
||||
stream: stream,
|
||||
streamToolCalls: true,
|
||||
mode: mode,
|
||||
...(session?.user?.name && { userName: session.user.name }),
|
||||
}),
|
||||
body: JSON.stringify(requestPayload),
|
||||
})
|
||||
|
||||
if (!simAgentResponse.ok) {
|
||||
const errorText = await simAgentResponse.text()
|
||||
if (simAgentResponse.status === 401 || simAgentResponse.status === 402) {
|
||||
// Rethrow status only; client will render appropriate assistant message
|
||||
return new NextResponse(null, { status: simAgentResponse.status })
|
||||
}
|
||||
|
||||
const errorText = await simAgentResponse.text().catch(() => '')
|
||||
logger.error(`[${tracker.requestId}] Sim agent API error:`, {
|
||||
status: simAgentResponse.status,
|
||||
error: errorText,
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
{ error: `Sim agent API error: ${simAgentResponse.statusText}` },
|
||||
{ status: simAgentResponse.status }
|
||||
@@ -369,15 +416,18 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// If streaming is requested, forward the stream and update chat later
|
||||
if (stream && simAgentResponse.body) {
|
||||
logger.info(`[${tracker.requestId}] Streaming response from sim agent`)
|
||||
|
||||
// Create user message to save
|
||||
const userMessage = {
|
||||
id: userMessageId || crypto.randomUUID(), // Use frontend ID if provided
|
||||
id: userMessageIdToUse, // Consistent ID used for request and persistence
|
||||
role: 'user',
|
||||
content: message,
|
||||
timestamp: new Date().toISOString(),
|
||||
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
|
||||
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
|
||||
...(Array.isArray(contexts) &&
|
||||
contexts.length > 0 && {
|
||||
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
|
||||
}),
|
||||
}
|
||||
|
||||
// Create a pass-through stream that captures the response
|
||||
@@ -387,7 +437,15 @@ export async function POST(req: NextRequest) {
|
||||
let assistantContent = ''
|
||||
const toolCalls: any[] = []
|
||||
let buffer = ''
|
||||
let isFirstDone = true
|
||||
const isFirstDone = true
|
||||
let responseIdFromStart: string | undefined
|
||||
let responseIdFromDone: string | undefined
|
||||
// Track tool call progress to identify a safe done event
|
||||
const announcedToolCallIds = new Set<string>()
|
||||
const startedToolExecutionIds = new Set<string>()
|
||||
const completedToolExecutionIds = new Set<string>()
|
||||
let lastDoneResponseId: string | undefined
|
||||
let lastSafeDoneResponseId: string | undefined
|
||||
|
||||
// Send chatId as first event
|
||||
if (actualChatId) {
|
||||
@@ -401,30 +459,30 @@ export async function POST(req: NextRequest) {
|
||||
|
||||
// Start title generation in parallel if needed
|
||||
if (actualChatId && !currentChat?.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${tracker.requestId}] Starting title generation with stream updates`, {
|
||||
chatId: actualChatId,
|
||||
hasTitle: !!currentChat?.title,
|
||||
conversationLength: conversationHistory.length,
|
||||
message: message.substring(0, 100) + (message.length > 100 ? '...' : ''),
|
||||
})
|
||||
generateChatTitleAsync(actualChatId, message, tracker.requestId, controller).catch(
|
||||
(error) => {
|
||||
generateChatTitle(message)
|
||||
.then(async (title) => {
|
||||
if (title) {
|
||||
await db
|
||||
.update(copilotChats)
|
||||
.set({
|
||||
title,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(copilotChats.id, actualChatId!))
|
||||
|
||||
const titleEvent = `data: ${JSON.stringify({
|
||||
type: 'title_updated',
|
||||
title: title,
|
||||
})}\n\n`
|
||||
controller.enqueue(encoder.encode(titleEvent))
|
||||
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
|
||||
}
|
||||
)
|
||||
})
|
||||
} else {
|
||||
logger.debug(`[${tracker.requestId}] Skipping title generation`, {
|
||||
chatId: actualChatId,
|
||||
hasTitle: !!currentChat?.title,
|
||||
conversationLength: conversationHistory.length,
|
||||
reason: !actualChatId
|
||||
? 'no chatId'
|
||||
: currentChat?.title
|
||||
? 'already has title'
|
||||
: conversationHistory.length > 0
|
||||
? 'not first message'
|
||||
: 'unknown',
|
||||
})
|
||||
logger.debug(`[${tracker.requestId}] Skipping title generation`)
|
||||
}
|
||||
|
||||
// Forward the sim agent stream and capture assistant response
|
||||
@@ -435,24 +493,9 @@ export async function POST(req: NextRequest) {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
logger.info(`[${tracker.requestId}] Stream reading completed`)
|
||||
break
|
||||
}
|
||||
|
||||
// Check if client disconnected before processing chunk
|
||||
try {
|
||||
// Forward the chunk to client immediately
|
||||
controller.enqueue(value)
|
||||
} catch (error) {
|
||||
// Client disconnected - stop reading from sim agent
|
||||
logger.info(
|
||||
`[${tracker.requestId}] Client disconnected, stopping stream processing`
|
||||
)
|
||||
reader.cancel() // Stop reading from sim agent
|
||||
break
|
||||
}
|
||||
const chunkSize = value.byteLength
|
||||
|
||||
// Decode and parse SSE events for logging and capturing content
|
||||
const decodedChunk = decoder.decode(value, { stream: true })
|
||||
buffer += decodedChunk
|
||||
@@ -486,37 +529,31 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
break
|
||||
|
||||
case 'tool_call':
|
||||
logger.info(
|
||||
`[${tracker.requestId}] Tool call ${event.data?.partial ? '(partial)' : '(complete)'}:`,
|
||||
{
|
||||
id: event.data?.id,
|
||||
name: event.data?.name,
|
||||
arguments: event.data?.arguments,
|
||||
blockIndex: event.data?._blockIndex,
|
||||
}
|
||||
case 'reasoning':
|
||||
logger.debug(
|
||||
`[${tracker.requestId}] Reasoning chunk received (${(event.data || event.content || '').length} chars)`
|
||||
)
|
||||
break
|
||||
|
||||
case 'tool_call':
|
||||
if (!event.data?.partial) {
|
||||
toolCalls.push(event.data)
|
||||
if (event.data?.id) {
|
||||
announcedToolCallIds.add(event.data.id)
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
case 'tool_execution':
|
||||
logger.info(`[${tracker.requestId}] Tool execution started:`, {
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
status: event.status,
|
||||
})
|
||||
case 'tool_generating':
|
||||
if (event.toolCallId) {
|
||||
startedToolExecutionIds.add(event.toolCallId)
|
||||
}
|
||||
break
|
||||
|
||||
case 'tool_result':
|
||||
logger.info(`[${tracker.requestId}] Tool result received:`, {
|
||||
toolCallId: event.toolCallId,
|
||||
toolName: event.toolName,
|
||||
success: event.success,
|
||||
result: `${JSON.stringify(event.result).substring(0, 200)}...`,
|
||||
resultSize: JSON.stringify(event.result).length,
|
||||
})
|
||||
if (event.toolCallId) {
|
||||
completedToolExecutionIds.add(event.toolCallId)
|
||||
}
|
||||
break
|
||||
|
||||
case 'tool_error':
|
||||
@@ -526,28 +563,78 @@ export async function POST(req: NextRequest) {
|
||||
error: event.error,
|
||||
success: event.success,
|
||||
})
|
||||
if (event.toolCallId) {
|
||||
completedToolExecutionIds.add(event.toolCallId)
|
||||
}
|
||||
break
|
||||
|
||||
case 'start':
|
||||
if (event.data?.responseId) {
|
||||
responseIdFromStart = event.data.responseId
|
||||
}
|
||||
break
|
||||
|
||||
case 'done':
|
||||
if (isFirstDone) {
|
||||
logger.info(
|
||||
`[${tracker.requestId}] Initial AI response complete, tool count: ${toolCalls.length}`
|
||||
)
|
||||
isFirstDone = false
|
||||
} else {
|
||||
logger.info(`[${tracker.requestId}] Conversation round complete`)
|
||||
if (event.data?.responseId) {
|
||||
responseIdFromDone = event.data.responseId
|
||||
lastDoneResponseId = responseIdFromDone
|
||||
|
||||
// Mark this done as safe only if no tool call is currently in progress or pending
|
||||
const announced = announcedToolCallIds.size
|
||||
const completed = completedToolExecutionIds.size
|
||||
const started = startedToolExecutionIds.size
|
||||
const hasToolInProgress = announced > completed || started > completed
|
||||
if (!hasToolInProgress) {
|
||||
lastSafeDoneResponseId = responseIdFromDone
|
||||
}
|
||||
}
|
||||
break
|
||||
|
||||
case 'error':
|
||||
logger.error(`[${tracker.requestId}] Stream error event:`, event.error)
|
||||
break
|
||||
|
||||
default:
|
||||
logger.debug(
|
||||
`[${tracker.requestId}] Unknown event type: ${event.type}`,
|
||||
event
|
||||
)
|
||||
}
|
||||
|
||||
// Emit to client: rewrite 'error' events into user-friendly assistant message
|
||||
if (event?.type === 'error') {
|
||||
try {
|
||||
const displayMessage: string =
|
||||
(event?.data && (event.data.displayMessage as string)) ||
|
||||
'Sorry, I encountered an error. Please try again.'
|
||||
const formatted = `_${displayMessage}_`
|
||||
// Accumulate so it persists to DB as assistant content
|
||||
assistantContent += formatted
|
||||
// Send as content chunk
|
||||
try {
|
||||
controller.enqueue(
|
||||
encoder.encode(
|
||||
`data: ${JSON.stringify({ type: 'content', data: formatted })}\n\n`
|
||||
)
|
||||
)
|
||||
} catch (enqueueErr) {
|
||||
reader.cancel()
|
||||
break
|
||||
}
|
||||
// Then close this response cleanly for the client
|
||||
try {
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ type: 'done' })}\n\n`)
|
||||
)
|
||||
} catch (enqueueErr) {
|
||||
reader.cancel()
|
||||
break
|
||||
}
|
||||
} catch {}
|
||||
// Do not forward the original error event
|
||||
} else {
|
||||
// Forward original event to client
|
||||
try {
|
||||
controller.enqueue(encoder.encode(`data: ${jsonStr}\n\n`))
|
||||
} catch (enqueueErr) {
|
||||
reader.cancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
// Enhanced error handling for large payloads and parsing issues
|
||||
@@ -581,10 +668,37 @@ export async function POST(req: NextRequest) {
|
||||
logger.debug(`[${tracker.requestId}] Processing remaining buffer: "${buffer}"`)
|
||||
if (buffer.startsWith('data: ')) {
|
||||
try {
|
||||
const event = JSON.parse(buffer.slice(6))
|
||||
const jsonStr = buffer.slice(6)
|
||||
const event = JSON.parse(jsonStr)
|
||||
if (event.type === 'content' && event.data) {
|
||||
assistantContent += event.data
|
||||
}
|
||||
// Forward remaining event, applying same error rewrite behavior
|
||||
if (event?.type === 'error') {
|
||||
const displayMessage: string =
|
||||
(event?.data && (event.data.displayMessage as string)) ||
|
||||
'Sorry, I encountered an error. Please try again.'
|
||||
const formatted = `_${displayMessage}_`
|
||||
assistantContent += formatted
|
||||
try {
|
||||
controller.enqueue(
|
||||
encoder.encode(
|
||||
`data: ${JSON.stringify({ type: 'content', data: formatted })}\n\n`
|
||||
)
|
||||
)
|
||||
controller.enqueue(
|
||||
encoder.encode(`data: ${JSON.stringify({ type: 'done' })}\n\n`)
|
||||
)
|
||||
} catch (enqueueErr) {
|
||||
reader.cancel()
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
controller.enqueue(encoder.encode(`data: ${jsonStr}\n\n`))
|
||||
} catch (enqueueErr) {
|
||||
reader.cancel()
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn(`[${tracker.requestId}] Failed to parse final buffer: "${buffer}"`)
|
||||
}
|
||||
@@ -622,12 +736,17 @@ export async function POST(req: NextRequest) {
|
||||
)
|
||||
}
|
||||
|
||||
// Persist only a safe conversationId to avoid continuing from a state that expects tool outputs
|
||||
const previousConversationId = currentChat?.conversationId as string | undefined
|
||||
const responseId = lastSafeDoneResponseId || previousConversationId || undefined
|
||||
|
||||
// Update chat in database immediately (without title)
|
||||
await db
|
||||
.update(copilotChats)
|
||||
.set({
|
||||
messages: updatedMessages,
|
||||
updatedAt: new Date(),
|
||||
...(responseId ? { conversationId: responseId } : {}),
|
||||
})
|
||||
.where(eq(copilotChats.id, actualChatId!))
|
||||
|
||||
@@ -635,6 +754,7 @@ export async function POST(req: NextRequest) {
|
||||
messageCount: updatedMessages.length,
|
||||
savedUserMessage: true,
|
||||
savedAssistantMessage: assistantContent.trim().length > 0,
|
||||
updatedConversationId: responseId || null,
|
||||
})
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -694,11 +814,16 @@ export async function POST(req: NextRequest) {
|
||||
// Save messages if we have a chat
|
||||
if (currentChat && responseData.content) {
|
||||
const userMessage = {
|
||||
id: userMessageId || crypto.randomUUID(), // Use frontend ID if provided
|
||||
id: userMessageIdToUse, // Consistent ID used for request and persistence
|
||||
role: 'user',
|
||||
content: message,
|
||||
timestamp: new Date().toISOString(),
|
||||
...(fileAttachments && fileAttachments.length > 0 && { fileAttachments }),
|
||||
...(Array.isArray(contexts) && contexts.length > 0 && { contexts }),
|
||||
...(Array.isArray(contexts) &&
|
||||
contexts.length > 0 && {
|
||||
contentBlocks: [{ type: 'contexts', contexts: contexts as any, timestamp: Date.now() }],
|
||||
}),
|
||||
}
|
||||
|
||||
const assistantMessage = {
|
||||
@@ -713,9 +838,22 @@ export async function POST(req: NextRequest) {
|
||||
// Start title generation in parallel if this is first message (non-streaming)
|
||||
if (actualChatId && !currentChat.title && conversationHistory.length === 0) {
|
||||
logger.info(`[${tracker.requestId}] Starting title generation for non-streaming response`)
|
||||
generateChatTitleAsync(actualChatId, message, tracker.requestId).catch((error) => {
|
||||
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
|
||||
})
|
||||
generateChatTitle(message)
|
||||
.then(async (title) => {
|
||||
if (title) {
|
||||
await db
|
||||
.update(copilotChats)
|
||||
.set({
|
||||
title,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(copilotChats.id, actualChatId!))
|
||||
logger.info(`[${tracker.requestId}] Generated and saved title: ${title}`)
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
logger.error(`[${tracker.requestId}] Title generation failed:`, error)
|
||||
})
|
||||
}
|
||||
|
||||
// Update chat in database immediately (without blocking for title)
|
||||
|
||||
@@ -229,7 +229,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists - override the default empty array
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
@@ -267,7 +266,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
messageCount: 2,
|
||||
})
|
||||
|
||||
// Verify database operations
|
||||
expect(mockSelect).toHaveBeenCalled()
|
||||
expect(mockUpdate).toHaveBeenCalled()
|
||||
expect(mockSet).toHaveBeenCalledWith({
|
||||
@@ -280,7 +278,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-456',
|
||||
userId: 'user-123',
|
||||
@@ -341,7 +338,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-789',
|
||||
userId: 'user-123',
|
||||
@@ -374,7 +370,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock database error during chat lookup
|
||||
mockLimit.mockRejectedValueOnce(new Error('Database connection failed'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
@@ -401,7 +396,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-123',
|
||||
userId: 'user-123',
|
||||
@@ -409,7 +403,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
// Mock database error during update
|
||||
mockSet.mockReturnValueOnce({
|
||||
where: vi.fn().mockRejectedValue(new Error('Update operation failed')),
|
||||
})
|
||||
@@ -438,7 +431,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Create a request with invalid JSON
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/chat/update-messages', {
|
||||
method: 'POST',
|
||||
body: '{invalid-json',
|
||||
@@ -459,7 +451,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-large',
|
||||
userId: 'user-123',
|
||||
@@ -467,7 +458,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
}
|
||||
mockLimit.mockResolvedValueOnce([existingChat])
|
||||
|
||||
// Create a large array of messages
|
||||
const messages = Array.from({ length: 100 }, (_, i) => ({
|
||||
id: `msg-${i + 1}`,
|
||||
role: i % 2 === 0 ? 'user' : 'assistant',
|
||||
@@ -500,7 +490,6 @@ describe('Copilot Chat Update Messages API Route', () => {
|
||||
const authMocks = mockAuth()
|
||||
authMocks.setAuthenticated()
|
||||
|
||||
// Mock chat exists
|
||||
const existingChat = {
|
||||
id: 'chat-mixed',
|
||||
userId: 'user-123',
|
||||
|
||||
@@ -28,7 +28,7 @@ const UpdateMessagesSchema = z.object({
|
||||
.array(
|
||||
z.object({
|
||||
id: z.string(),
|
||||
s3_key: z.string(),
|
||||
key: z.string(),
|
||||
filename: z.string(),
|
||||
media_type: z.string(),
|
||||
size: z.number(),
|
||||
@@ -51,12 +51,6 @@ export async function POST(req: NextRequest) {
|
||||
const body = await req.json()
|
||||
const { chatId, messages } = UpdateMessagesSchema.parse(body)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Updating chat messages`, {
|
||||
userId,
|
||||
chatId,
|
||||
messageCount: messages.length,
|
||||
})
|
||||
|
||||
// Verify that the chat belongs to the user
|
||||
const [chat] = await db
|
||||
.select()
|
||||
|
||||
39
apps/sim/app/api/copilot/chats/route.ts
Normal file
39
apps/sim/app/api/copilot/chats/route.ts
Normal file
@@ -0,0 +1,39 @@
|
||||
import { desc, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createInternalServerErrorResponse,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { copilotChats } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('CopilotChatsListAPI')
|
||||
|
||||
export async function GET(_req: NextRequest) {
|
||||
try {
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const chats = await db
|
||||
.select({
|
||||
id: copilotChats.id,
|
||||
title: copilotChats.title,
|
||||
workflowId: copilotChats.workflowId,
|
||||
updatedAt: copilotChats.updatedAt,
|
||||
})
|
||||
.from(copilotChats)
|
||||
.where(eq(copilotChats.userId, userId))
|
||||
.orderBy(desc(copilotChats.updatedAt))
|
||||
|
||||
logger.info(`Retrieved ${chats.length} chats for user ${userId}`)
|
||||
|
||||
return NextResponse.json({ success: true, chats })
|
||||
} catch (error) {
|
||||
logger.error('Error fetching user copilot chats:', error)
|
||||
return createInternalServerErrorResponse('Failed to fetch user chats')
|
||||
}
|
||||
}
|
||||
@@ -38,7 +38,7 @@ async function updateToolCallStatus(
|
||||
|
||||
try {
|
||||
const key = `tool_call:${toolCallId}`
|
||||
const timeout = 60000 // 1 minute timeout
|
||||
const timeout = 600000 // 10 minutes timeout for user confirmation
|
||||
const pollInterval = 100 // Poll every 100ms
|
||||
const startTime = Date.now()
|
||||
|
||||
@@ -48,11 +48,6 @@ async function updateToolCallStatus(
|
||||
while (Date.now() - startTime < timeout) {
|
||||
const exists = await redis.exists(key)
|
||||
if (exists) {
|
||||
logger.info('Tool call found in Redis, updating status', {
|
||||
toolCallId,
|
||||
key,
|
||||
pollDuration: Date.now() - startTime,
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
@@ -79,27 +74,8 @@ async function updateToolCallStatus(
|
||||
timestamp: new Date().toISOString(),
|
||||
}
|
||||
|
||||
// Log what we're about to update in Redis
|
||||
logger.info('About to update Redis with tool call data', {
|
||||
toolCallId,
|
||||
key,
|
||||
toolCallData,
|
||||
serializedData: JSON.stringify(toolCallData),
|
||||
providedStatus: status,
|
||||
providedMessage: message,
|
||||
messageIsUndefined: message === undefined,
|
||||
messageIsNull: message === null,
|
||||
})
|
||||
|
||||
await redis.set(key, JSON.stringify(toolCallData), 'EX', 86400) // Keep 24 hour expiry
|
||||
|
||||
logger.info('Tool call status updated in Redis', {
|
||||
toolCallId,
|
||||
key,
|
||||
status,
|
||||
message,
|
||||
pollDuration: Date.now() - startTime,
|
||||
})
|
||||
return true
|
||||
} catch (error) {
|
||||
logger.error('Failed to update tool call status in Redis', {
|
||||
@@ -131,13 +107,6 @@ export async function POST(req: NextRequest) {
|
||||
const body = await req.json()
|
||||
const { toolCallId, status, message } = ConfirmationSchema.parse(body)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Tool call confirmation request`, {
|
||||
userId: authenticatedUserId,
|
||||
toolCallId,
|
||||
status,
|
||||
message,
|
||||
})
|
||||
|
||||
// Update the tool call status in Redis
|
||||
const updated = await updateToolCallStatus(toolCallId, status, message)
|
||||
|
||||
@@ -153,13 +122,6 @@ export async function POST(req: NextRequest) {
|
||||
}
|
||||
|
||||
const duration = tracker.getDuration()
|
||||
logger.info(`[${tracker.requestId}] Tool call confirmation completed`, {
|
||||
userId: authenticatedUserId,
|
||||
toolCallId,
|
||||
status,
|
||||
internalStatus: status,
|
||||
duration,
|
||||
})
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { routeExecution } from '@/lib/copilot/tools/server/router'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('ExecuteCopilotServerToolAPI')
|
||||
|
||||
const ExecuteSchema = z.object({
|
||||
toolName: z.string(),
|
||||
payload: z.unknown().optional(),
|
||||
})
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
try {
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
try {
|
||||
const preview = JSON.stringify(body).slice(0, 300)
|
||||
logger.debug(`[${tracker.requestId}] Incoming request body preview`, { preview })
|
||||
} catch {}
|
||||
|
||||
const { toolName, payload } = ExecuteSchema.parse(body)
|
||||
|
||||
logger.info(`[${tracker.requestId}] Executing server tool`, { toolName })
|
||||
const result = await routeExecution(toolName, payload)
|
||||
|
||||
try {
|
||||
const resultPreview = JSON.stringify(result).slice(0, 300)
|
||||
logger.debug(`[${tracker.requestId}] Server tool result preview`, { toolName, resultPreview })
|
||||
} catch {}
|
||||
|
||||
return NextResponse.json({ success: true, result })
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.debug(`[${tracker.requestId}] Zod validation error`, { issues: error.issues })
|
||||
return createBadRequestResponse('Invalid request body for execute-copilot-server-tool')
|
||||
}
|
||||
logger.error(`[${tracker.requestId}] Failed to execute server tool:`, error)
|
||||
return createInternalServerErrorResponse('Failed to execute server tool')
|
||||
}
|
||||
}
|
||||
@@ -1,762 +1,7 @@
|
||||
/**
|
||||
* Tests for copilot methods API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockCryptoUuid,
|
||||
setupCommonApiMocks,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
describe('Copilot Methods API Route', () => {
|
||||
const mockRedisGet = vi.fn()
|
||||
const mockRedisSet = vi.fn()
|
||||
const mockGetRedisClient = vi.fn()
|
||||
const mockToolRegistryHas = vi.fn()
|
||||
const mockToolRegistryGet = vi.fn()
|
||||
const mockToolRegistryExecute = vi.fn()
|
||||
const mockToolRegistryGetAvailableIds = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
setupCommonApiMocks()
|
||||
mockCryptoUuid()
|
||||
|
||||
// Mock Redis client
|
||||
const mockRedisClient = {
|
||||
get: mockRedisGet,
|
||||
set: mockRedisSet,
|
||||
}
|
||||
|
||||
mockGetRedisClient.mockReturnValue(mockRedisClient)
|
||||
mockRedisGet.mockResolvedValue(null)
|
||||
mockRedisSet.mockResolvedValue('OK')
|
||||
|
||||
vi.doMock('@/lib/redis', () => ({
|
||||
getRedisClient: mockGetRedisClient,
|
||||
}))
|
||||
|
||||
// Mock tool registry
|
||||
const mockToolRegistry = {
|
||||
has: mockToolRegistryHas,
|
||||
get: mockToolRegistryGet,
|
||||
execute: mockToolRegistryExecute,
|
||||
getAvailableIds: mockToolRegistryGetAvailableIds,
|
||||
}
|
||||
|
||||
mockToolRegistryHas.mockReturnValue(true)
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: false })
|
||||
mockToolRegistryExecute.mockResolvedValue({ success: true, data: 'Tool executed successfully' })
|
||||
mockToolRegistryGetAvailableIds.mockReturnValue(['test-tool', 'another-tool'])
|
||||
|
||||
vi.doMock('@/lib/copilot/tools/server-tools/registry', () => ({
|
||||
copilotToolRegistry: mockToolRegistry,
|
||||
}))
|
||||
|
||||
// Mock environment variables
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
INTERNAL_API_SECRET: 'test-secret-key',
|
||||
},
|
||||
}))
|
||||
|
||||
// Mock setTimeout for polling
|
||||
vi.spyOn(global, 'setTimeout').mockImplementation((callback, _delay) => {
|
||||
if (typeof callback === 'function') {
|
||||
setImmediate(callback)
|
||||
}
|
||||
return setTimeout(() => {}, 0) as any
|
||||
})
|
||||
|
||||
// Mock Date.now for timeout control
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 1000 // Add 1 second each call
|
||||
return mockTime
|
||||
})
|
||||
|
||||
// Mock crypto.randomUUID for request IDs
|
||||
vi.spyOn(crypto, 'randomUUID').mockReturnValue('test-request-id')
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
it('should return 401 when API key is missing', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'API key required',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 401 when API key is invalid', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'invalid-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Invalid API key',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 401 when internal API key is not configured', async () => {
|
||||
// Mock environment with no API key
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
INTERNAL_API_SECRET: undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'any-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Internal API key not configured',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return 400 for invalid request body - missing methodId', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
params: {},
|
||||
// Missing methodId
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Required')
|
||||
})
|
||||
|
||||
it('should return 400 for empty methodId', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: '',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Method ID is required')
|
||||
})
|
||||
|
||||
it('should return 400 when tool is not found in registry', async () => {
|
||||
mockToolRegistryHas.mockReturnValue(false)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'unknown-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('Unknown method: unknown-tool')
|
||||
expect(responseData.error).toContain('Available methods: test-tool, another-tool')
|
||||
})
|
||||
|
||||
it('should successfully execute a tool without interruption', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
params: { key: 'value' },
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', { key: 'value' })
|
||||
})
|
||||
|
||||
it('should handle tool execution with default empty params', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'test-tool',
|
||||
// No params provided
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('test-tool', {})
|
||||
})
|
||||
|
||||
it('should return 400 when tool requires interrupt but no toolCallId provided', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
// No toolCallId provided
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe(
|
||||
'This tool requires approval but no tool call ID was provided'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - user approval', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return accepted status immediately (simulate quick approval)
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'accepted', message: 'User approved' })
|
||||
)
|
||||
|
||||
// Reset Date.now mock to not trigger timeout
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 100 // Small increment to avoid timeout
|
||||
return mockTime
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: { key: 'value' },
|
||||
toolCallId: 'tool-call-123',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
// Verify Redis operations
|
||||
expect(mockRedisSet).toHaveBeenCalledWith(
|
||||
'tool_call:tool-call-123',
|
||||
expect.stringContaining('"status":"pending"'),
|
||||
'EX',
|
||||
86400
|
||||
)
|
||||
expect(mockRedisGet).toHaveBeenCalledWith('tool_call:tool-call-123')
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('interrupt-tool', {
|
||||
key: 'value',
|
||||
confirmationMessage: 'User approved',
|
||||
fullData: {
|
||||
message: 'User approved',
|
||||
status: 'accepted',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - user rejection', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return rejected status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'rejected', message: 'User rejected' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-456',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200) // User rejection returns 200
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe(
|
||||
'The user decided to skip running this tool. This was a user decision.'
|
||||
)
|
||||
|
||||
// Tool should not be executed when rejected
|
||||
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - error status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return error status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'error', message: 'Tool execution failed' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-error',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution failed')
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - background status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return background status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'background', message: 'Running in background' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-bg',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - success status', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return success status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'success', message: 'Completed successfully' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-success',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle tool execution with interrupt - timeout', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to never return a status (timeout scenario)
|
||||
mockRedisGet.mockResolvedValue(null)
|
||||
|
||||
// Mock Date.now to trigger timeout quickly
|
||||
let mockTime = 1640995200000
|
||||
vi.spyOn(Date, 'now').mockImplementation(() => {
|
||||
mockTime += 100000 // Add 100 seconds each call to trigger timeout
|
||||
return mockTime
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-timeout',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Request Timeout
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
|
||||
expect(mockToolRegistryExecute).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle unexpected status in interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return unexpected status
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'unknown-status', message: 'Unknown' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-unknown',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Unexpected tool call status: unknown-status')
|
||||
})
|
||||
|
||||
it('should handle Redis client unavailable for interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
mockGetRedisClient.mockReturnValue(null)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-no-redis',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Timeout due to Redis unavailable
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
})
|
||||
|
||||
it('should handle no_op tool with confirmation message', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return accepted status with message
|
||||
mockRedisGet.mockResolvedValue(
|
||||
JSON.stringify({ status: 'accepted', message: 'Confirmation message' })
|
||||
)
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'no_op',
|
||||
params: { existing: 'param' },
|
||||
toolCallId: 'tool-call-noop',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
|
||||
// Verify confirmation message was added to params
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalledWith('no_op', {
|
||||
existing: 'param',
|
||||
confirmationMessage: 'Confirmation message',
|
||||
fullData: {
|
||||
message: 'Confirmation message',
|
||||
status: 'accepted',
|
||||
},
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle Redis errors in interrupt flow', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to throw an error
|
||||
mockRedisGet.mockRejectedValue(new Error('Redis connection failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-redis-error',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(408) // Timeout due to Redis error
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Tool execution request timed out')
|
||||
})
|
||||
|
||||
it('should handle tool execution failure', async () => {
|
||||
mockToolRegistryExecute.mockResolvedValue({
|
||||
success: false,
|
||||
error: 'Tool execution failed',
|
||||
})
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'failing-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200) // Still returns 200, but with success: false
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: false,
|
||||
error: 'Tool execution failed',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle JSON parsing errors in request body', async () => {
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: '{invalid-json',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toContain('JSON')
|
||||
})
|
||||
|
||||
it('should handle tool registry execution throwing an error', async () => {
|
||||
mockToolRegistryExecute.mockRejectedValue(new Error('Registry execution failed'))
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'error-tool',
|
||||
params: {},
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const responseData = await response.json()
|
||||
expect(responseData.success).toBe(false)
|
||||
expect(responseData.error).toBe('Registry execution failed')
|
||||
})
|
||||
|
||||
it('should handle old format Redis status (string instead of JSON)', async () => {
|
||||
mockToolRegistryGet.mockReturnValue({ requiresInterrupt: true })
|
||||
|
||||
// Mock Redis to return old format (direct status string)
|
||||
mockRedisGet.mockResolvedValue('accepted')
|
||||
|
||||
const req = new NextRequest('http://localhost:3000/api/copilot/methods', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'x-api-key': 'test-secret-key',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
methodId: 'interrupt-tool',
|
||||
params: {},
|
||||
toolCallId: 'tool-call-old-format',
|
||||
}),
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/copilot/methods/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
const responseData = await response.json()
|
||||
expect(responseData).toEqual({
|
||||
success: true,
|
||||
data: 'Tool executed successfully',
|
||||
})
|
||||
|
||||
expect(mockToolRegistryExecute).toHaveBeenCalled()
|
||||
})
|
||||
describe('copilot methods route placeholder', () => {
|
||||
it('loads test suite', () => {
|
||||
expect(true).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,436 +0,0 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { copilotToolRegistry } from '@/lib/copilot/tools/server-tools/registry'
|
||||
import type { NotificationStatus } from '@/lib/copilot/types'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getRedisClient } from '@/lib/redis'
|
||||
import { createErrorResponse } from '@/app/api/copilot/methods/utils'
|
||||
|
||||
const logger = createLogger('CopilotMethodsAPI')
|
||||
|
||||
/**
|
||||
* Add a tool call to Redis with 'pending' status
|
||||
*/
|
||||
async function addToolToRedis(toolCallId: string): Promise<void> {
|
||||
if (!toolCallId) {
|
||||
logger.warn('addToolToRedis: No tool call ID provided')
|
||||
return
|
||||
}
|
||||
|
||||
const redis = getRedisClient()
|
||||
if (!redis) {
|
||||
logger.warn('addToolToRedis: Redis client not available')
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
const key = `tool_call:${toolCallId}`
|
||||
const status: NotificationStatus = 'pending'
|
||||
|
||||
// Store as JSON object for consistency with confirm API
|
||||
const toolCallData = {
|
||||
status,
|
||||
message: null,
|
||||
timestamp: new Date().toISOString(),
|
||||
}
|
||||
|
||||
// Set with 24 hour expiry (86400 seconds)
|
||||
await redis.set(key, JSON.stringify(toolCallData), 'EX', 86400)
|
||||
|
||||
logger.info('Tool call added to Redis', {
|
||||
toolCallId,
|
||||
key,
|
||||
status,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to add tool call to Redis', {
|
||||
toolCallId,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll Redis for tool call status updates
|
||||
* Returns when status changes to 'Accepted' or 'Rejected', or times out after 60 seconds
|
||||
*/
|
||||
async function pollRedisForTool(
|
||||
toolCallId: string
|
||||
): Promise<{ status: NotificationStatus; message?: string; fullData?: any } | null> {
|
||||
const redis = getRedisClient()
|
||||
if (!redis) {
|
||||
logger.warn('pollRedisForTool: Redis client not available')
|
||||
return null
|
||||
}
|
||||
|
||||
const key = `tool_call:${toolCallId}`
|
||||
const timeout = 300000 // 5 minutes
|
||||
const pollInterval = 1000 // 1 second
|
||||
const startTime = Date.now()
|
||||
|
||||
logger.info('Starting to poll Redis for tool call status', {
|
||||
toolCallId,
|
||||
timeout,
|
||||
pollInterval,
|
||||
})
|
||||
|
||||
while (Date.now() - startTime < timeout) {
|
||||
try {
|
||||
const redisValue = await redis.get(key)
|
||||
if (!redisValue) {
|
||||
// Wait before next poll
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval))
|
||||
continue
|
||||
}
|
||||
|
||||
let status: NotificationStatus | null = null
|
||||
let message: string | undefined
|
||||
let fullData: any = null
|
||||
|
||||
// Try to parse as JSON (new format), fallback to string (old format)
|
||||
try {
|
||||
const parsedData = JSON.parse(redisValue)
|
||||
status = parsedData.status as NotificationStatus
|
||||
message = parsedData.message || undefined
|
||||
fullData = parsedData // Store the full parsed data
|
||||
} catch {
|
||||
// Fallback to old format (direct status string)
|
||||
status = redisValue as NotificationStatus
|
||||
}
|
||||
|
||||
if (status !== 'pending') {
|
||||
// Log the message found in redis prominently - always log, even if message is null/undefined
|
||||
logger.info('Redis poller found non-pending status', {
|
||||
toolCallId,
|
||||
foundMessage: message,
|
||||
messageType: typeof message,
|
||||
messageIsNull: message === null,
|
||||
messageIsUndefined: message === undefined,
|
||||
status,
|
||||
duration: Date.now() - startTime,
|
||||
rawRedisValue: redisValue,
|
||||
})
|
||||
|
||||
logger.info('Tool call status resolved', {
|
||||
toolCallId,
|
||||
status,
|
||||
message,
|
||||
duration: Date.now() - startTime,
|
||||
rawRedisValue: redisValue,
|
||||
parsedAsJSON: redisValue
|
||||
? (() => {
|
||||
try {
|
||||
return JSON.parse(redisValue)
|
||||
} catch {
|
||||
return 'failed-to-parse'
|
||||
}
|
||||
})()
|
||||
: null,
|
||||
})
|
||||
|
||||
// Special logging for set environment variables tool when Redis status is found
|
||||
if (toolCallId && (status === 'accepted' || status === 'rejected')) {
|
||||
logger.info('SET_ENV_VARS: Redis polling found status update', {
|
||||
toolCallId,
|
||||
foundStatus: status,
|
||||
redisMessage: message,
|
||||
pollDuration: Date.now() - startTime,
|
||||
redisKey: `tool_call:${toolCallId}`,
|
||||
})
|
||||
}
|
||||
|
||||
return { status, message, fullData }
|
||||
}
|
||||
|
||||
// Wait before next poll
|
||||
await new Promise((resolve) => setTimeout(resolve, pollInterval))
|
||||
} catch (error) {
|
||||
logger.error('Error polling Redis for tool call status', {
|
||||
toolCallId,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
logger.warn('Tool call polling timed out', {
|
||||
toolCallId,
|
||||
timeout,
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle tool calls that require user interruption/approval
|
||||
* Returns { approved: boolean, rejected: boolean, error?: boolean, message?: string } to distinguish between rejection, timeout, and error
|
||||
*/
|
||||
async function interruptHandler(toolCallId: string): Promise<{
|
||||
approved: boolean
|
||||
rejected: boolean
|
||||
error?: boolean
|
||||
message?: string
|
||||
fullData?: any
|
||||
}> {
|
||||
if (!toolCallId) {
|
||||
logger.error('interruptHandler: No tool call ID provided')
|
||||
return { approved: false, rejected: false, error: true, message: 'No tool call ID provided' }
|
||||
}
|
||||
|
||||
logger.info('Starting interrupt handler for tool call', { toolCallId })
|
||||
|
||||
try {
|
||||
// Step 1: Add tool to Redis with 'pending' status
|
||||
await addToolToRedis(toolCallId)
|
||||
|
||||
// Step 2: Poll Redis for status update
|
||||
const result = await pollRedisForTool(toolCallId)
|
||||
|
||||
if (!result) {
|
||||
logger.error('Failed to get tool call status or timed out', { toolCallId })
|
||||
return { approved: false, rejected: false }
|
||||
}
|
||||
|
||||
const { status, message, fullData } = result
|
||||
|
||||
if (status === 'rejected') {
|
||||
logger.info('Tool execution rejected by user', { toolCallId, message })
|
||||
return { approved: false, rejected: true, message, fullData }
|
||||
}
|
||||
|
||||
if (status === 'accepted') {
|
||||
logger.info('Tool execution approved by user', { toolCallId, message })
|
||||
return { approved: true, rejected: false, message, fullData }
|
||||
}
|
||||
|
||||
if (status === 'error') {
|
||||
logger.error('Tool execution failed with error', { toolCallId, message })
|
||||
return { approved: false, rejected: false, error: true, message, fullData }
|
||||
}
|
||||
|
||||
if (status === 'background') {
|
||||
logger.info('Tool execution moved to background', { toolCallId, message })
|
||||
return { approved: true, rejected: false, message, fullData }
|
||||
}
|
||||
|
||||
if (status === 'success') {
|
||||
logger.info('Tool execution completed successfully', { toolCallId, message })
|
||||
return { approved: true, rejected: false, message, fullData }
|
||||
}
|
||||
|
||||
logger.warn('Unexpected tool call status', { toolCallId, status, message })
|
||||
return {
|
||||
approved: false,
|
||||
rejected: false,
|
||||
error: true,
|
||||
message: `Unexpected tool call status: ${status}`,
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
|
||||
logger.error('Error in interrupt handler', {
|
||||
toolCallId,
|
||||
error: errorMessage,
|
||||
})
|
||||
return {
|
||||
approved: false,
|
||||
rejected: false,
|
||||
error: true,
|
||||
message: `Interrupt handler error: ${errorMessage}`,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Schema for method execution
|
||||
const MethodExecutionSchema = z.object({
|
||||
methodId: z.string().min(1, 'Method ID is required'),
|
||||
params: z.record(z.any()).optional().default({}),
|
||||
toolCallId: z.string().nullable().optional().default(null),
|
||||
})
|
||||
|
||||
// Simple internal API key authentication
|
||||
function checkInternalApiKey(req: NextRequest) {
|
||||
const apiKey = req.headers.get('x-api-key')
|
||||
const expectedApiKey = env.INTERNAL_API_SECRET
|
||||
|
||||
if (!expectedApiKey) {
|
||||
return { success: false, error: 'Internal API key not configured' }
|
||||
}
|
||||
|
||||
if (!apiKey) {
|
||||
return { success: false, error: 'API key required' }
|
||||
}
|
||||
|
||||
if (apiKey !== expectedApiKey) {
|
||||
return { success: false, error: 'Invalid API key' }
|
||||
}
|
||||
|
||||
return { success: true }
|
||||
}
|
||||
|
||||
/**
|
||||
* POST /api/copilot/methods
|
||||
* Execute a method based on methodId with internal API key auth
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID()
|
||||
const startTime = Date.now()
|
||||
|
||||
try {
|
||||
// Check authentication (internal API key)
|
||||
const authResult = checkInternalApiKey(req)
|
||||
if (!authResult.success) {
|
||||
return NextResponse.json(createErrorResponse(authResult.error || 'Authentication failed'), {
|
||||
status: 401,
|
||||
})
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { methodId, params, toolCallId } = MethodExecutionSchema.parse(body)
|
||||
|
||||
logger.info(`[${requestId}] Method execution request: ${methodId}`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
hasParams: !!params && Object.keys(params).length > 0,
|
||||
})
|
||||
|
||||
// Check if tool exists in registry
|
||||
if (!copilotToolRegistry.has(methodId)) {
|
||||
logger.error(`[${requestId}] Tool not found in registry: ${methodId}`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
availableTools: copilotToolRegistry.getAvailableIds(),
|
||||
registrySize: copilotToolRegistry.getAvailableIds().length,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse(
|
||||
`Unknown method: ${methodId}. Available methods: ${copilotToolRegistry.getAvailableIds().join(', ')}`
|
||||
),
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Tool found in registry: ${methodId}`, {
|
||||
toolCallId,
|
||||
})
|
||||
|
||||
// Check if the tool requires interrupt/approval
|
||||
const tool = copilotToolRegistry.get(methodId)
|
||||
if (tool?.requiresInterrupt) {
|
||||
if (!toolCallId) {
|
||||
logger.warn(`[${requestId}] Tool requires interrupt but no toolCallId provided`, {
|
||||
methodId,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse('This tool requires approval but no tool call ID was provided'),
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Tool requires interrupt, starting approval process`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
})
|
||||
|
||||
// Handle interrupt flow
|
||||
const { approved, rejected, error, message, fullData } = await interruptHandler(toolCallId)
|
||||
|
||||
if (rejected) {
|
||||
logger.info(`[${requestId}] Tool execution rejected by user`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
message,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse(
|
||||
'The user decided to skip running this tool. This was a user decision.'
|
||||
),
|
||||
{ status: 200 } // Changed to 200 - user rejection is a valid response
|
||||
)
|
||||
}
|
||||
|
||||
if (error) {
|
||||
logger.error(`[${requestId}] Tool execution failed with error`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
message,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse(message || 'Tool execution failed with unknown error'),
|
||||
{ status: 500 } // 500 Internal Server Error
|
||||
)
|
||||
}
|
||||
|
||||
if (!approved) {
|
||||
logger.warn(`[${requestId}] Tool execution timed out`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse('Tool execution request timed out'),
|
||||
{ status: 408 } // 408 Request Timeout
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Tool execution approved by user`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
message,
|
||||
})
|
||||
|
||||
// For tools that need confirmation data, pass the message and/or fullData as parameters
|
||||
if (message) {
|
||||
params.confirmationMessage = message
|
||||
}
|
||||
if (fullData) {
|
||||
params.fullData = fullData
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the tool directly via registry
|
||||
const result = await copilotToolRegistry.execute(methodId, params)
|
||||
|
||||
logger.info(`[${requestId}] Tool execution result:`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
success: result.success,
|
||||
hasData: !!result.data,
|
||||
hasError: !!result.error,
|
||||
})
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(`[${requestId}] Method execution completed: ${methodId}`, {
|
||||
methodId,
|
||||
toolCallId,
|
||||
duration,
|
||||
success: result.success,
|
||||
})
|
||||
|
||||
return NextResponse.json(result)
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.error(`[${requestId}] Request validation error:`, {
|
||||
duration,
|
||||
errors: error.errors,
|
||||
})
|
||||
return NextResponse.json(
|
||||
createErrorResponse(
|
||||
`Invalid request data: ${error.errors.map((e) => e.message).join(', ')}`
|
||||
),
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
logger.error(`[${requestId}] Unexpected error:`, {
|
||||
duration,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
createErrorResponse(error instanceof Error ? error.message : 'Internal server error'),
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
import type { CopilotToolResponse } from '@/lib/copilot/tools/server-tools/base'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
const logger = createLogger('CopilotMethodsUtils')
|
||||
|
||||
/**
|
||||
* Create a standardized error response
|
||||
*/
|
||||
export function createErrorResponse(error: string): CopilotToolResponse {
|
||||
return {
|
||||
success: false,
|
||||
error,
|
||||
}
|
||||
}
|
||||
68
apps/sim/app/api/copilot/stats/route.ts
Normal file
68
apps/sim/app/api/copilot/stats/route.ts
Normal file
@@ -0,0 +1,68 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { env } from '@/lib/env'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
const BodySchema = z.object({
|
||||
messageId: z.string(),
|
||||
diffCreated: z.boolean(),
|
||||
diffAccepted: z.boolean(),
|
||||
})
|
||||
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
try {
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const json = await req.json().catch(() => ({}))
|
||||
const parsed = BodySchema.safeParse(json)
|
||||
if (!parsed.success) {
|
||||
return createBadRequestResponse('Invalid request body for copilot stats')
|
||||
}
|
||||
|
||||
const { messageId, diffCreated, diffAccepted } = parsed.data as any
|
||||
|
||||
// Build outgoing payload for Sim Agent with only required fields
|
||||
const payload: Record<string, any> = {
|
||||
messageId,
|
||||
diffCreated,
|
||||
diffAccepted,
|
||||
}
|
||||
|
||||
const agentRes = await fetch(`${SIM_AGENT_API_URL}/api/stats`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify(payload),
|
||||
})
|
||||
|
||||
// Prefer not to block clients; still relay status
|
||||
let agentJson: any = null
|
||||
try {
|
||||
agentJson = await agentRes.json()
|
||||
} catch {}
|
||||
|
||||
if (!agentRes.ok) {
|
||||
const message = (agentJson && (agentJson.error || agentJson.message)) || 'Upstream error'
|
||||
return NextResponse.json({ success: false, error: message }, { status: 400 })
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true })
|
||||
} catch (error) {
|
||||
return createInternalServerErrorResponse('Failed to forward copilot stats')
|
||||
}
|
||||
}
|
||||
125
apps/sim/app/api/copilot/tools/mark-complete/route.ts
Normal file
125
apps/sim/app/api/copilot/tools/mark-complete/route.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import {
|
||||
authenticateCopilotRequestSessionOnly,
|
||||
createBadRequestResponse,
|
||||
createInternalServerErrorResponse,
|
||||
createRequestTracker,
|
||||
createUnauthorizedResponse,
|
||||
} from '@/lib/copilot/auth'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { SIM_AGENT_API_URL_DEFAULT } from '@/lib/sim-agent'
|
||||
|
||||
const logger = createLogger('CopilotMarkToolCompleteAPI')
|
||||
|
||||
// Sim Agent API configuration
|
||||
const SIM_AGENT_API_URL = env.SIM_AGENT_API_URL || SIM_AGENT_API_URL_DEFAULT
|
||||
|
||||
// Schema for mark-complete request
|
||||
const MarkCompleteSchema = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
status: z.number().int(),
|
||||
message: z.any().optional(),
|
||||
data: z.any().optional(),
|
||||
})
|
||||
|
||||
/**
|
||||
* POST /api/copilot/tools/mark-complete
|
||||
* Proxy to Sim Agent: POST /api/tools/mark-complete
|
||||
*/
|
||||
export async function POST(req: NextRequest) {
|
||||
const tracker = createRequestTracker()
|
||||
|
||||
try {
|
||||
const { userId, isAuthenticated } = await authenticateCopilotRequestSessionOnly()
|
||||
if (!isAuthenticated || !userId) {
|
||||
return createUnauthorizedResponse()
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
|
||||
// Log raw body shape for diagnostics (avoid dumping huge payloads)
|
||||
try {
|
||||
const bodyPreview = JSON.stringify(body).slice(0, 300)
|
||||
logger.debug(`[${tracker.requestId}] Incoming mark-complete raw body preview`, {
|
||||
preview: `${bodyPreview}${bodyPreview.length === 300 ? '...' : ''}`,
|
||||
})
|
||||
} catch {}
|
||||
|
||||
const parsed = MarkCompleteSchema.parse(body)
|
||||
|
||||
const messagePreview = (() => {
|
||||
try {
|
||||
const s =
|
||||
typeof parsed.message === 'string' ? parsed.message : JSON.stringify(parsed.message)
|
||||
return s ? `${s.slice(0, 200)}${s.length > 200 ? '...' : ''}` : undefined
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
})()
|
||||
|
||||
logger.info(`[${tracker.requestId}] Forwarding tool mark-complete`, {
|
||||
userId,
|
||||
toolCallId: parsed.id,
|
||||
toolName: parsed.name,
|
||||
status: parsed.status,
|
||||
hasMessage: parsed.message !== undefined,
|
||||
hasData: parsed.data !== undefined,
|
||||
messagePreview,
|
||||
agentUrl: `${SIM_AGENT_API_URL}/api/tools/mark-complete`,
|
||||
})
|
||||
|
||||
const agentRes = await fetch(`${SIM_AGENT_API_URL}/api/tools/mark-complete`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...(env.COPILOT_API_KEY ? { 'x-api-key': env.COPILOT_API_KEY } : {}),
|
||||
},
|
||||
body: JSON.stringify(parsed),
|
||||
})
|
||||
|
||||
// Attempt to parse agent response JSON
|
||||
let agentJson: any = null
|
||||
let agentText: string | null = null
|
||||
try {
|
||||
agentJson = await agentRes.json()
|
||||
} catch (_) {
|
||||
try {
|
||||
agentText = await agentRes.text()
|
||||
} catch {}
|
||||
}
|
||||
|
||||
logger.info(`[${tracker.requestId}] Agent responded to mark-complete`, {
|
||||
status: agentRes.status,
|
||||
ok: agentRes.ok,
|
||||
responseJsonPreview: agentJson ? JSON.stringify(agentJson).slice(0, 300) : undefined,
|
||||
responseTextPreview: agentText ? agentText.slice(0, 300) : undefined,
|
||||
})
|
||||
|
||||
if (agentRes.ok) {
|
||||
return NextResponse.json({ success: true })
|
||||
}
|
||||
|
||||
const errorMessage =
|
||||
agentJson?.error || agentText || `Agent responded with status ${agentRes.status}`
|
||||
const status = agentRes.status >= 500 ? 500 : 400
|
||||
|
||||
logger.warn(`[${tracker.requestId}] Mark-complete failed`, {
|
||||
status,
|
||||
error: errorMessage,
|
||||
})
|
||||
|
||||
return NextResponse.json({ success: false, error: errorMessage }, { status })
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
logger.warn(`[${tracker.requestId}] Invalid mark-complete request body`, {
|
||||
issues: error.issues,
|
||||
})
|
||||
return createBadRequestResponse('Invalid request body for mark-complete')
|
||||
}
|
||||
logger.error(`[${tracker.requestId}] Failed to proxy mark-complete:`, error)
|
||||
return createInternalServerErrorResponse('Failed to mark tool as complete')
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,6 @@ import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import { decryptSecret, encryptSecret } from '@/lib/utils'
|
||||
import { db } from '@/db'
|
||||
import { environment } from '@/db/schema'
|
||||
@@ -13,7 +10,6 @@ import type { EnvironmentVariable } from '@/stores/settings/environment/types'
|
||||
|
||||
const logger = createLogger('EnvironmentAPI')
|
||||
|
||||
// Schema for environment variable updates
|
||||
const EnvVarSchema = z.object({
|
||||
variables: z.record(z.string()),
|
||||
})
|
||||
@@ -33,17 +29,13 @@ export async function POST(req: NextRequest) {
|
||||
try {
|
||||
const { variables } = EnvVarSchema.parse(body)
|
||||
|
||||
// Encrypt all variables
|
||||
const encryptedVariables = await Object.entries(variables).reduce(
|
||||
async (accPromise, [key, value]) => {
|
||||
const acc = await accPromise
|
||||
const encryptedVariables = await Promise.all(
|
||||
Object.entries(variables).map(async ([key, value]) => {
|
||||
const { encrypted } = await encryptSecret(value)
|
||||
return { ...acc, [key]: encrypted }
|
||||
},
|
||||
Promise.resolve({})
|
||||
)
|
||||
return [key, encrypted] as const
|
||||
})
|
||||
).then((entries) => Object.fromEntries(entries))
|
||||
|
||||
// Replace all environment variables for user
|
||||
await db
|
||||
.insert(environment)
|
||||
.values({
|
||||
@@ -83,7 +75,6 @@ export async function GET(request: Request) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
// Get the session directly in the API route
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
|
||||
@@ -102,18 +93,15 @@ export async function GET(request: Request) {
|
||||
return NextResponse.json({ data: {} }, { status: 200 })
|
||||
}
|
||||
|
||||
// Decrypt the variables for client-side use
|
||||
const encryptedVariables = result[0].variables as Record<string, string>
|
||||
const decryptedVariables: Record<string, EnvironmentVariable> = {}
|
||||
|
||||
// Decrypt each variable
|
||||
for (const [key, encryptedValue] of Object.entries(encryptedVariables)) {
|
||||
try {
|
||||
const { decrypted } = await decryptSecret(encryptedValue)
|
||||
decryptedVariables[key] = { key, value: decrypted }
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error decrypting variable ${key}`, error)
|
||||
// If decryption fails, provide a placeholder
|
||||
decryptedVariables[key] = { key, value: '' }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,223 +0,0 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getEnvironmentVariableKeys } from '@/lib/environment/utils'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { decryptSecret, encryptSecret } from '@/lib/utils'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import { db } from '@/db'
|
||||
import { environment } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('EnvironmentVariablesAPI')
|
||||
|
||||
// Schema for environment variable updates
|
||||
const EnvVarSchema = z.object({
|
||||
variables: z.record(z.string()),
|
||||
})
|
||||
|
||||
export async function GET(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
// For GET requests, check for workflowId in query params
|
||||
const { searchParams } = new URL(request.url)
|
||||
const workflowId = searchParams.get('workflowId')
|
||||
|
||||
// Use dual authentication pattern like other copilot tools
|
||||
const userId = await getUserId(requestId, workflowId || undefined)
|
||||
|
||||
if (!userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get only the variable names (keys), not values
|
||||
const result = await getEnvironmentVariableKeys(userId)
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
output: result,
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Environment variables fetch error`, error)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: error.message || 'Failed to get environment variables',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export async function PUT(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
const { workflowId, variables } = body
|
||||
|
||||
// Use dual authentication pattern like other copilot tools
|
||||
const userId = await getUserId(requestId, workflowId)
|
||||
|
||||
if (!userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized environment variables set attempt`)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
try {
|
||||
const { variables: validatedVariables } = EnvVarSchema.parse({ variables })
|
||||
|
||||
// Get existing environment variables for this user
|
||||
const existingData = await db
|
||||
.select()
|
||||
.from(environment)
|
||||
.where(eq(environment.userId, userId))
|
||||
.limit(1)
|
||||
|
||||
// Start with existing encrypted variables or empty object
|
||||
const existingEncryptedVariables =
|
||||
(existingData[0]?.variables as Record<string, string>) || {}
|
||||
|
||||
// Determine which variables are new or changed by comparing with decrypted existing values
|
||||
const variablesToEncrypt: Record<string, string> = {}
|
||||
const addedVariables: string[] = []
|
||||
const updatedVariables: string[] = []
|
||||
|
||||
for (const [key, newValue] of Object.entries(validatedVariables)) {
|
||||
if (!(key in existingEncryptedVariables)) {
|
||||
// New variable
|
||||
variablesToEncrypt[key] = newValue
|
||||
addedVariables.push(key)
|
||||
} else {
|
||||
// Check if the value has actually changed by decrypting the existing value
|
||||
try {
|
||||
const { decrypted: existingValue } = await decryptSecret(
|
||||
existingEncryptedVariables[key]
|
||||
)
|
||||
|
||||
if (existingValue !== newValue) {
|
||||
// Value changed, needs re-encryption
|
||||
variablesToEncrypt[key] = newValue
|
||||
updatedVariables.push(key)
|
||||
}
|
||||
// If values are the same, keep the existing encrypted value
|
||||
} catch (decryptError) {
|
||||
// If we can't decrypt the existing value, treat as changed and re-encrypt
|
||||
logger.warn(
|
||||
`[${requestId}] Could not decrypt existing variable ${key}, re-encrypting`,
|
||||
{ error: decryptError }
|
||||
)
|
||||
variablesToEncrypt[key] = newValue
|
||||
updatedVariables.push(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only encrypt the variables that are new or changed
|
||||
const newlyEncryptedVariables = await Object.entries(variablesToEncrypt).reduce(
|
||||
async (accPromise, [key, value]) => {
|
||||
const acc = await accPromise
|
||||
const { encrypted } = await encryptSecret(value)
|
||||
return { ...acc, [key]: encrypted }
|
||||
},
|
||||
Promise.resolve({})
|
||||
)
|
||||
|
||||
// Merge existing encrypted variables with newly encrypted ones
|
||||
const finalEncryptedVariables = { ...existingEncryptedVariables, ...newlyEncryptedVariables }
|
||||
|
||||
// Update or insert environment variables for user
|
||||
await db
|
||||
.insert(environment)
|
||||
.values({
|
||||
id: crypto.randomUUID(),
|
||||
userId: userId,
|
||||
variables: finalEncryptedVariables,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.onConflictDoUpdate({
|
||||
target: [environment.userId],
|
||||
set: {
|
||||
variables: finalEncryptedVariables,
|
||||
updatedAt: new Date(),
|
||||
},
|
||||
})
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
output: {
|
||||
message: `Successfully processed ${Object.keys(validatedVariables).length} environment variable(s): ${addedVariables.length} added, ${updatedVariables.length} updated`,
|
||||
variableCount: Object.keys(validatedVariables).length,
|
||||
variableNames: Object.keys(validatedVariables),
|
||||
totalVariableCount: Object.keys(finalEncryptedVariables).length,
|
||||
addedVariables,
|
||||
updatedVariables,
|
||||
},
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid environment variables data`, {
|
||||
errors: validationError.errors,
|
||||
})
|
||||
return NextResponse.json(
|
||||
{ error: 'Invalid request data', details: validationError.errors },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
throw validationError
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Environment variables set error`, error)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: error.message || 'Failed to set environment variables',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const body = await request.json()
|
||||
const { workflowId } = body
|
||||
|
||||
// Use dual authentication pattern like other copilot tools
|
||||
const userId = await getUserId(requestId, workflowId)
|
||||
|
||||
if (!userId) {
|
||||
logger.warn(`[${requestId}] Unauthorized environment variables access attempt`)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get only the variable names (keys), not values
|
||||
const result = await getEnvironmentVariableKeys(userId)
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: true,
|
||||
output: result,
|
||||
},
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] Environment variables fetch error`, error)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: error.message || 'Failed to get environment variables',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -1,16 +1,8 @@
|
||||
import {
|
||||
AbortMultipartUploadCommand,
|
||||
CompleteMultipartUploadCommand,
|
||||
CreateMultipartUploadCommand,
|
||||
UploadPartCommand,
|
||||
} from '@aws-sdk/client-s3'
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
|
||||
import { S3_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
import { BLOB_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
|
||||
const logger = createLogger('MultipartUploadAPI')
|
||||
|
||||
@@ -26,15 +18,6 @@ interface GetPartUrlsRequest {
|
||||
partNumbers: number[]
|
||||
}
|
||||
|
||||
interface CompleteMultipartRequest {
|
||||
uploadId: string
|
||||
key: string
|
||||
parts: Array<{
|
||||
ETag: string
|
||||
PartNumber: number
|
||||
}>
|
||||
}
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
@@ -44,106 +27,214 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
const action = request.nextUrl.searchParams.get('action')
|
||||
|
||||
if (!isUsingCloudStorage() || getStorageProvider() !== 's3') {
|
||||
if (!isUsingCloudStorage()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Multipart upload is only available with S3 storage' },
|
||||
{ error: 'Multipart upload is only available with cloud storage (S3 or Azure Blob)' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
|
||||
const s3Client = getS3Client()
|
||||
const storageProvider = getStorageProvider()
|
||||
|
||||
switch (action) {
|
||||
case 'initiate': {
|
||||
const data: InitiateMultipartRequest = await request.json()
|
||||
const { fileName, contentType } = data
|
||||
const { fileName, contentType, fileSize } = data
|
||||
|
||||
const safeFileName = fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const uniqueKey = `kb/${uuidv4()}-${safeFileName}`
|
||||
if (storageProvider === 's3') {
|
||||
const { initiateS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
|
||||
|
||||
const command = new CreateMultipartUploadCommand({
|
||||
Bucket: S3_KB_CONFIG.bucket,
|
||||
Key: uniqueKey,
|
||||
ContentType: contentType,
|
||||
Metadata: {
|
||||
originalName: fileName,
|
||||
uploadedAt: new Date().toISOString(),
|
||||
purpose: 'knowledge-base',
|
||||
},
|
||||
})
|
||||
const result = await initiateS3MultipartUpload({
|
||||
fileName,
|
||||
contentType,
|
||||
fileSize,
|
||||
})
|
||||
|
||||
const response = await s3Client.send(command)
|
||||
logger.info(`Initiated S3 multipart upload for ${fileName}: ${result.uploadId}`)
|
||||
|
||||
logger.info(`Initiated multipart upload for ${fileName}: ${response.UploadId}`)
|
||||
return NextResponse.json({
|
||||
uploadId: result.uploadId,
|
||||
key: result.key,
|
||||
})
|
||||
}
|
||||
if (storageProvider === 'blob') {
|
||||
const { initiateMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
|
||||
|
||||
return NextResponse.json({
|
||||
uploadId: response.UploadId,
|
||||
key: uniqueKey,
|
||||
})
|
||||
const result = await initiateMultipartUpload({
|
||||
fileName,
|
||||
contentType,
|
||||
fileSize,
|
||||
customConfig: {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
},
|
||||
})
|
||||
|
||||
logger.info(`Initiated Azure multipart upload for ${fileName}: ${result.uploadId}`)
|
||||
|
||||
return NextResponse.json({
|
||||
uploadId: result.uploadId,
|
||||
key: result.key,
|
||||
})
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported storage provider: ${storageProvider}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
case 'get-part-urls': {
|
||||
const data: GetPartUrlsRequest = await request.json()
|
||||
const { uploadId, key, partNumbers } = data
|
||||
|
||||
const presignedUrls = await Promise.all(
|
||||
partNumbers.map(async (partNumber) => {
|
||||
const command = new UploadPartCommand({
|
||||
Bucket: S3_KB_CONFIG.bucket,
|
||||
Key: key,
|
||||
PartNumber: partNumber,
|
||||
UploadId: uploadId,
|
||||
})
|
||||
if (storageProvider === 's3') {
|
||||
const { getS3MultipartPartUrls } = await import('@/lib/uploads/s3/s3-client')
|
||||
|
||||
const url = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
|
||||
return { partNumber, url }
|
||||
const presignedUrls = await getS3MultipartPartUrls(key, uploadId, partNumbers)
|
||||
|
||||
return NextResponse.json({ presignedUrls })
|
||||
}
|
||||
if (storageProvider === 'blob') {
|
||||
const { getMultipartPartUrls } = await import('@/lib/uploads/blob/blob-client')
|
||||
|
||||
const presignedUrls = await getMultipartPartUrls(key, uploadId, partNumbers, {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
})
|
||||
)
|
||||
|
||||
return NextResponse.json({ presignedUrls })
|
||||
return NextResponse.json({ presignedUrls })
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported storage provider: ${storageProvider}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
case 'complete': {
|
||||
const data: CompleteMultipartRequest = await request.json()
|
||||
const data = await request.json()
|
||||
|
||||
// Handle batch completion
|
||||
if ('uploads' in data) {
|
||||
const results = await Promise.all(
|
||||
data.uploads.map(async (upload: any) => {
|
||||
const { uploadId, key } = upload
|
||||
|
||||
if (storageProvider === 's3') {
|
||||
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
|
||||
const parts = upload.parts // S3 format: { ETag, PartNumber }
|
||||
|
||||
const result = await completeS3MultipartUpload(key, uploadId, parts)
|
||||
|
||||
return {
|
||||
success: true,
|
||||
location: result.location,
|
||||
path: result.path,
|
||||
key: result.key,
|
||||
}
|
||||
}
|
||||
if (storageProvider === 'blob') {
|
||||
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
|
||||
const parts = upload.parts // Azure format: { blockId, partNumber }
|
||||
|
||||
const result = await completeMultipartUpload(key, uploadId, parts, {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
})
|
||||
|
||||
return {
|
||||
success: true,
|
||||
location: result.location,
|
||||
path: result.path,
|
||||
key: result.key,
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`Unsupported storage provider: ${storageProvider}`)
|
||||
})
|
||||
)
|
||||
|
||||
logger.info(`Completed ${data.uploads.length} multipart uploads`)
|
||||
return NextResponse.json({ results })
|
||||
}
|
||||
|
||||
// Handle single completion
|
||||
const { uploadId, key, parts } = data
|
||||
|
||||
const command = new CompleteMultipartUploadCommand({
|
||||
Bucket: S3_KB_CONFIG.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
MultipartUpload: {
|
||||
Parts: parts.sort((a, b) => a.PartNumber - b.PartNumber),
|
||||
},
|
||||
})
|
||||
if (storageProvider === 's3') {
|
||||
const { completeS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
|
||||
|
||||
const response = await s3Client.send(command)
|
||||
const result = await completeS3MultipartUpload(key, uploadId, parts)
|
||||
|
||||
logger.info(`Completed multipart upload for key ${key}`)
|
||||
logger.info(`Completed S3 multipart upload for key ${key}`)
|
||||
|
||||
const finalPath = `/api/files/serve/s3/${encodeURIComponent(key)}`
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
location: result.location,
|
||||
path: result.path,
|
||||
key: result.key,
|
||||
})
|
||||
}
|
||||
if (storageProvider === 'blob') {
|
||||
const { completeMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
location: response.Location,
|
||||
path: finalPath,
|
||||
key,
|
||||
})
|
||||
const result = await completeMultipartUpload(key, uploadId, parts, {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
})
|
||||
|
||||
logger.info(`Completed Azure multipart upload for key ${key}`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
location: result.location,
|
||||
path: result.path,
|
||||
key: result.key,
|
||||
})
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported storage provider: ${storageProvider}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
case 'abort': {
|
||||
const data = await request.json()
|
||||
const { uploadId, key } = data
|
||||
|
||||
const command = new AbortMultipartUploadCommand({
|
||||
Bucket: S3_KB_CONFIG.bucket,
|
||||
Key: key,
|
||||
UploadId: uploadId,
|
||||
})
|
||||
if (storageProvider === 's3') {
|
||||
const { abortS3MultipartUpload } = await import('@/lib/uploads/s3/s3-client')
|
||||
|
||||
await s3Client.send(command)
|
||||
await abortS3MultipartUpload(key, uploadId)
|
||||
|
||||
logger.info(`Aborted multipart upload for key ${key}`)
|
||||
logger.info(`Aborted S3 multipart upload for key ${key}`)
|
||||
} else if (storageProvider === 'blob') {
|
||||
const { abortMultipartUpload } = await import('@/lib/uploads/blob/blob-client')
|
||||
|
||||
await abortMultipartUpload(key, uploadId, {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
})
|
||||
|
||||
logger.info(`Aborted Azure multipart upload for key ${key}`)
|
||||
} else {
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported storage provider: ${storageProvider}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
return NextResponse.json({ success: true })
|
||||
}
|
||||
|
||||
@@ -76,11 +76,9 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
logger.info('File parse request received:', { filePath, fileType })
|
||||
|
||||
// Handle multiple files
|
||||
if (Array.isArray(filePath)) {
|
||||
const results = []
|
||||
for (const path of filePath) {
|
||||
// Skip empty or invalid paths
|
||||
if (!path || (typeof path === 'string' && path.trim() === '')) {
|
||||
results.push({
|
||||
success: false,
|
||||
@@ -91,12 +89,10 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
|
||||
const result = await parseFileSingle(path, fileType)
|
||||
// Add processing time to metadata
|
||||
if (result.metadata) {
|
||||
result.metadata.processingTime = Date.now() - startTime
|
||||
}
|
||||
|
||||
// Transform each result to match expected frontend format
|
||||
if (result.success) {
|
||||
results.push({
|
||||
success: true,
|
||||
@@ -105,7 +101,7 @@ export async function POST(request: NextRequest) {
|
||||
name: result.filePath.split('/').pop() || 'unknown',
|
||||
fileType: result.metadata?.fileType || 'application/octet-stream',
|
||||
size: result.metadata?.size || 0,
|
||||
binary: false, // We only return text content
|
||||
binary: false,
|
||||
},
|
||||
filePath: result.filePath,
|
||||
})
|
||||
@@ -120,15 +116,12 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
}
|
||||
|
||||
// Handle single file
|
||||
const result = await parseFileSingle(filePath, fileType)
|
||||
|
||||
// Add processing time to metadata
|
||||
if (result.metadata) {
|
||||
result.metadata.processingTime = Date.now() - startTime
|
||||
}
|
||||
|
||||
// Transform single file result to match expected frontend format
|
||||
if (result.success) {
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -142,8 +135,6 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
}
|
||||
|
||||
// Only return 500 for actual server errors, not file processing failures
|
||||
// File processing failures (like file not found, parsing errors) should return 200 with success:false
|
||||
return NextResponse.json(result)
|
||||
} catch (error) {
|
||||
logger.error('Error in file parse API:', error)
|
||||
@@ -164,7 +155,6 @@ export async function POST(request: NextRequest) {
|
||||
async function parseFileSingle(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
logger.info('Parsing file:', filePath)
|
||||
|
||||
// Validate that filePath is not empty
|
||||
if (!filePath || filePath.trim() === '') {
|
||||
return {
|
||||
success: false,
|
||||
@@ -173,7 +163,6 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<Par
|
||||
}
|
||||
}
|
||||
|
||||
// Validate path for security before any processing
|
||||
const pathValidation = validateFilePath(filePath)
|
||||
if (!pathValidation.isValid) {
|
||||
return {
|
||||
@@ -183,49 +172,40 @@ async function parseFileSingle(filePath: string, fileType?: string): Promise<Par
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is an external URL
|
||||
if (filePath.startsWith('http://') || filePath.startsWith('https://')) {
|
||||
return handleExternalUrl(filePath, fileType)
|
||||
}
|
||||
|
||||
// Check if this is a cloud storage path (S3 or Blob)
|
||||
const isS3Path = filePath.includes('/api/files/serve/s3/')
|
||||
const isBlobPath = filePath.includes('/api/files/serve/blob/')
|
||||
|
||||
// Use cloud handler if it's a cloud path or we're in cloud mode
|
||||
if (isS3Path || isBlobPath || isUsingCloudStorage()) {
|
||||
return handleCloudFile(filePath, fileType)
|
||||
}
|
||||
|
||||
// Use local handler for local files
|
||||
return handleLocalFile(filePath, fileType)
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate file path for security
|
||||
* Validate file path for security - prevents null byte injection and path traversal attacks
|
||||
*/
|
||||
function validateFilePath(filePath: string): { isValid: boolean; error?: string } {
|
||||
// Check for null bytes
|
||||
if (filePath.includes('\0')) {
|
||||
return { isValid: false, error: 'Invalid path: null byte detected' }
|
||||
}
|
||||
|
||||
// Check for path traversal attempts
|
||||
if (filePath.includes('..')) {
|
||||
return { isValid: false, error: 'Access denied: path traversal detected' }
|
||||
}
|
||||
|
||||
// Check for tilde characters (home directory access)
|
||||
if (filePath.includes('~')) {
|
||||
return { isValid: false, error: 'Invalid path: tilde character not allowed' }
|
||||
}
|
||||
|
||||
// Check for absolute paths outside allowed directories
|
||||
if (filePath.startsWith('/') && !filePath.startsWith('/api/files/serve/')) {
|
||||
return { isValid: false, error: 'Path outside allowed directory' }
|
||||
}
|
||||
|
||||
// Check for Windows absolute paths
|
||||
if (/^[A-Za-z]:\\/.test(filePath)) {
|
||||
return { isValid: false, error: 'Path outside allowed directory' }
|
||||
}
|
||||
@@ -260,12 +240,10 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
|
||||
logger.info(`Downloaded file from URL: ${url}, size: ${buffer.length} bytes`)
|
||||
|
||||
// Extract filename from URL
|
||||
const urlPath = new URL(url).pathname
|
||||
const filename = urlPath.split('/').pop() || 'download'
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
// Process the file based on its content type
|
||||
if (extension === 'pdf') {
|
||||
return await handlePdfBuffer(buffer, filename, fileType, url)
|
||||
}
|
||||
@@ -276,7 +254,6 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
return await handleGenericTextBuffer(buffer, filename, extension, fileType, url)
|
||||
}
|
||||
|
||||
// For binary or unknown files
|
||||
return handleGenericBuffer(buffer, filename, extension, fileType)
|
||||
} catch (error) {
|
||||
logger.error(`Error handling external URL ${url}:`, error)
|
||||
@@ -289,35 +266,29 @@ async function handleExternalUrl(url: string, fileType?: string): Promise<ParseR
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle file stored in cloud storage (S3 or Azure Blob)
|
||||
* Handle file stored in cloud storage
|
||||
*/
|
||||
async function handleCloudFile(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
try {
|
||||
// Extract the cloud key from the path
|
||||
let cloudKey: string
|
||||
if (filePath.includes('/api/files/serve/s3/')) {
|
||||
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/s3/')[1])
|
||||
} else if (filePath.includes('/api/files/serve/blob/')) {
|
||||
cloudKey = decodeURIComponent(filePath.split('/api/files/serve/blob/')[1])
|
||||
} else if (filePath.startsWith('/api/files/serve/')) {
|
||||
// Backwards-compatibility: path like "/api/files/serve/<key>"
|
||||
cloudKey = decodeURIComponent(filePath.substring('/api/files/serve/'.length))
|
||||
} else {
|
||||
// Assume raw key provided
|
||||
cloudKey = filePath
|
||||
}
|
||||
|
||||
logger.info('Extracted cloud key:', cloudKey)
|
||||
|
||||
// Download the file from cloud storage - this can throw for access errors
|
||||
const fileBuffer = await downloadFile(cloudKey)
|
||||
logger.info(`Downloaded file from cloud storage: ${cloudKey}, size: ${fileBuffer.length} bytes`)
|
||||
|
||||
// Extract the filename from the cloud key
|
||||
const filename = cloudKey.split('/').pop() || cloudKey
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
// Process the file based on its content type
|
||||
if (extension === 'pdf') {
|
||||
return await handlePdfBuffer(fileBuffer, filename, fileType, filePath)
|
||||
}
|
||||
@@ -325,22 +296,19 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<Par
|
||||
return await handleCsvBuffer(fileBuffer, filename, fileType, filePath)
|
||||
}
|
||||
if (isSupportedFileType(extension)) {
|
||||
// For other supported types that we have parsers for
|
||||
return await handleGenericTextBuffer(fileBuffer, filename, extension, fileType, filePath)
|
||||
}
|
||||
// For binary or unknown files
|
||||
return handleGenericBuffer(fileBuffer, filename, extension, fileType)
|
||||
} catch (error) {
|
||||
logger.error(`Error handling cloud file ${filePath}:`, error)
|
||||
|
||||
// Check if this is a download/access error that should trigger a 500 response
|
||||
// For download/access errors, throw to trigger 500 response
|
||||
const errorMessage = (error as Error).message
|
||||
if (errorMessage.includes('Access denied') || errorMessage.includes('Forbidden')) {
|
||||
// For access errors, throw to trigger 500 response
|
||||
throw new Error(`Error accessing file from cloud storage: ${errorMessage}`)
|
||||
}
|
||||
|
||||
// For other errors (parsing, processing), return success:false
|
||||
// For other errors (parsing, processing), return success:false and an error message
|
||||
return {
|
||||
success: false,
|
||||
error: `Error accessing file from cloud storage: ${errorMessage}`,
|
||||
@@ -354,28 +322,23 @@ async function handleCloudFile(filePath: string, fileType?: string): Promise<Par
|
||||
*/
|
||||
async function handleLocalFile(filePath: string, fileType?: string): Promise<ParseResult> {
|
||||
try {
|
||||
// Extract filename from path
|
||||
const filename = filePath.split('/').pop() || filePath
|
||||
const fullPath = path.join(UPLOAD_DIR_SERVER, filename)
|
||||
|
||||
logger.info('Processing local file:', fullPath)
|
||||
|
||||
// Check if file exists
|
||||
try {
|
||||
await fsPromises.access(fullPath)
|
||||
} catch {
|
||||
throw new Error(`File not found: ${filename}`)
|
||||
}
|
||||
|
||||
// Parse the file directly
|
||||
const result = await parseFile(fullPath)
|
||||
|
||||
// Get file stats for metadata
|
||||
const stats = await fsPromises.stat(fullPath)
|
||||
const fileBuffer = await readFile(fullPath)
|
||||
const hash = createHash('md5').update(fileBuffer).digest('hex')
|
||||
|
||||
// Extract file extension for type detection
|
||||
const extension = path.extname(filename).toLowerCase().substring(1)
|
||||
|
||||
return {
|
||||
@@ -386,7 +349,7 @@ async function handleLocalFile(filePath: string, fileType?: string): Promise<Par
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: stats.size,
|
||||
hash,
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -425,15 +388,14 @@ async function handlePdfBuffer(
|
||||
fileType: fileType || 'application/pdf',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to parse PDF in memory:', error)
|
||||
|
||||
// Create fallback message for PDF parsing failure
|
||||
const content = createPdfFailureMessage(
|
||||
0, // We can't determine page count without parsing
|
||||
0,
|
||||
fileBuffer.length,
|
||||
originalPath || filename,
|
||||
(error as Error).message
|
||||
@@ -447,7 +409,7 @@ async function handlePdfBuffer(
|
||||
fileType: fileType || 'application/pdf',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -465,7 +427,6 @@ async function handleCsvBuffer(
|
||||
try {
|
||||
logger.info(`Parsing CSV in memory: ${filename}`)
|
||||
|
||||
// Use the parseBuffer function from our library
|
||||
const { parseBuffer } = await import('@/lib/file-parsers')
|
||||
const result = await parseBuffer(fileBuffer, 'csv')
|
||||
|
||||
@@ -477,7 +438,7 @@ async function handleCsvBuffer(
|
||||
fileType: fileType || 'text/csv',
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -490,7 +451,7 @@ async function handleCsvBuffer(
|
||||
fileType: 'text/csv',
|
||||
size: 0,
|
||||
hash: '',
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -509,7 +470,6 @@ async function handleGenericTextBuffer(
|
||||
try {
|
||||
logger.info(`Parsing text file in memory: ${filename}`)
|
||||
|
||||
// Try to use a specialized parser if available
|
||||
try {
|
||||
const { parseBuffer, isSupportedFileType } = await import('@/lib/file-parsers')
|
||||
|
||||
@@ -524,7 +484,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -532,7 +492,6 @@ async function handleGenericTextBuffer(
|
||||
logger.warn('Specialized parser failed, falling back to generic parsing:', parserError)
|
||||
}
|
||||
|
||||
// Fallback to generic text parsing
|
||||
const content = fileBuffer.toString('utf-8')
|
||||
|
||||
return {
|
||||
@@ -543,7 +502,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -556,7 +515,7 @@ async function handleGenericTextBuffer(
|
||||
fileType: 'text/plain',
|
||||
size: 0,
|
||||
hash: '',
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -584,7 +543,7 @@ function handleGenericBuffer(
|
||||
fileType: fileType || getMimeType(extension),
|
||||
size: fileBuffer.length,
|
||||
hash: createHash('md5').update(fileBuffer).digest('hex'),
|
||||
processingTime: 0, // Will be set by caller
|
||||
processingTime: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -594,25 +553,11 @@ function handleGenericBuffer(
|
||||
*/
|
||||
async function parseBufferAsPdf(buffer: Buffer) {
|
||||
try {
|
||||
// Import parsers dynamically to avoid initialization issues in tests
|
||||
// First try to use the main PDF parser
|
||||
try {
|
||||
const { PdfParser } = await import('@/lib/file-parsers/pdf-parser')
|
||||
const parser = new PdfParser()
|
||||
logger.info('Using main PDF parser for buffer')
|
||||
const { PdfParser } = await import('@/lib/file-parsers/pdf-parser')
|
||||
const parser = new PdfParser()
|
||||
logger.info('Using main PDF parser for buffer')
|
||||
|
||||
if (parser.parseBuffer) {
|
||||
return await parser.parseBuffer(buffer)
|
||||
}
|
||||
throw new Error('PDF parser does not support buffer parsing')
|
||||
} catch (error) {
|
||||
// Fallback to raw PDF parser
|
||||
logger.warn('Main PDF parser failed, using raw parser for buffer:', error)
|
||||
const { RawPdfParser } = await import('@/lib/file-parsers/raw-pdf-parser')
|
||||
const rawParser = new RawPdfParser()
|
||||
|
||||
return await rawParser.parseBuffer(buffer)
|
||||
}
|
||||
return await parser.parseBuffer(buffer)
|
||||
} catch (error) {
|
||||
throw new Error(`PDF parsing failed: ${(error as Error).message}`)
|
||||
}
|
||||
@@ -655,7 +600,7 @@ Please use a PDF viewer for best results.`
|
||||
}
|
||||
|
||||
/**
|
||||
* Create error message for PDF parsing failure
|
||||
* Create error message for PDF parsing failure and make it more readable
|
||||
*/
|
||||
function createPdfFailureMessage(
|
||||
pageCount: number,
|
||||
|
||||
361
apps/sim/app/api/files/presigned/batch/route.ts
Normal file
361
apps/sim/app/api/files/presigned/batch/route.ts
Normal file
@@ -0,0 +1,361 @@
|
||||
import { PutObjectCommand } from '@aws-sdk/client-s3'
|
||||
import { getSignedUrl } from '@aws-sdk/s3-request-presigner'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
|
||||
import {
|
||||
BLOB_CHAT_CONFIG,
|
||||
BLOB_CONFIG,
|
||||
BLOB_COPILOT_CONFIG,
|
||||
BLOB_KB_CONFIG,
|
||||
S3_CHAT_CONFIG,
|
||||
S3_CONFIG,
|
||||
S3_COPILOT_CONFIG,
|
||||
S3_KB_CONFIG,
|
||||
} from '@/lib/uploads/setup'
|
||||
import { validateFileType } from '@/lib/uploads/validation'
|
||||
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
|
||||
|
||||
const logger = createLogger('BatchPresignedUploadAPI')
|
||||
|
||||
interface BatchFileRequest {
|
||||
fileName: string
|
||||
contentType: string
|
||||
fileSize: number
|
||||
}
|
||||
|
||||
interface BatchPresignedUrlRequest {
|
||||
files: BatchFileRequest[]
|
||||
}
|
||||
|
||||
type UploadType = 'general' | 'knowledge-base' | 'chat' | 'copilot'
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
let data: BatchPresignedUrlRequest
|
||||
try {
|
||||
data = await request.json()
|
||||
} catch {
|
||||
return NextResponse.json({ error: 'Invalid JSON in request body' }, { status: 400 })
|
||||
}
|
||||
|
||||
const { files } = data
|
||||
|
||||
if (!files || !Array.isArray(files) || files.length === 0) {
|
||||
return NextResponse.json(
|
||||
{ error: 'files array is required and cannot be empty' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (files.length > 100) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Cannot process more than 100 files at once' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const uploadTypeParam = request.nextUrl.searchParams.get('type')
|
||||
const uploadType: UploadType =
|
||||
uploadTypeParam === 'knowledge-base'
|
||||
? 'knowledge-base'
|
||||
: uploadTypeParam === 'chat'
|
||||
? 'chat'
|
||||
: uploadTypeParam === 'copilot'
|
||||
? 'copilot'
|
||||
: 'general'
|
||||
|
||||
const MAX_FILE_SIZE = 100 * 1024 * 1024
|
||||
for (const file of files) {
|
||||
if (!file.fileName?.trim()) {
|
||||
return NextResponse.json({ error: 'fileName is required for all files' }, { status: 400 })
|
||||
}
|
||||
if (!file.contentType?.trim()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'contentType is required for all files' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
if (!file.fileSize || file.fileSize <= 0) {
|
||||
return NextResponse.json(
|
||||
{ error: 'fileSize must be positive for all files' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
if (file.fileSize > MAX_FILE_SIZE) {
|
||||
return NextResponse.json(
|
||||
{ error: `File ${file.fileName} exceeds maximum size of ${MAX_FILE_SIZE} bytes` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
const fileValidationError = validateFileType(file.fileName, file.contentType)
|
||||
if (fileValidationError) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: fileValidationError.message,
|
||||
code: fileValidationError.code,
|
||||
supportedTypes: fileValidationError.supportedTypes,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const sessionUserId = session.user.id
|
||||
|
||||
if (uploadType === 'copilot' && !sessionUserId?.trim()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Authenticated user session is required for copilot uploads' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!isUsingCloudStorage()) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Direct uploads are only available when cloud storage is enabled' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
const storageProvider = getStorageProvider()
|
||||
logger.info(
|
||||
`Generating batch ${uploadType} presigned URLs for ${files.length} files using ${storageProvider}`
|
||||
)
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
let result
|
||||
switch (storageProvider) {
|
||||
case 's3':
|
||||
result = await handleBatchS3PresignedUrls(files, uploadType, sessionUserId)
|
||||
break
|
||||
case 'blob':
|
||||
result = await handleBatchBlobPresignedUrls(files, uploadType, sessionUserId)
|
||||
break
|
||||
default:
|
||||
return NextResponse.json(
|
||||
{ error: `Unknown storage provider: ${storageProvider}` },
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
logger.info(
|
||||
`Generated ${files.length} presigned URLs in ${duration}ms (avg ${Math.round(duration / files.length)}ms per file)`
|
||||
)
|
||||
|
||||
return NextResponse.json(result)
|
||||
} catch (error) {
|
||||
logger.error('Error generating batch presigned URLs:', error)
|
||||
return createErrorResponse(
|
||||
error instanceof Error ? error : new Error('Failed to generate batch presigned URLs')
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async function handleBatchS3PresignedUrls(
|
||||
files: BatchFileRequest[],
|
||||
uploadType: UploadType,
|
||||
userId?: string
|
||||
) {
|
||||
const config =
|
||||
uploadType === 'knowledge-base'
|
||||
? S3_KB_CONFIG
|
||||
: uploadType === 'chat'
|
||||
? S3_CHAT_CONFIG
|
||||
: uploadType === 'copilot'
|
||||
? S3_COPILOT_CONFIG
|
||||
: S3_CONFIG
|
||||
|
||||
if (!config.bucket || !config.region) {
|
||||
throw new Error(`S3 configuration missing for ${uploadType} uploads`)
|
||||
}
|
||||
|
||||
const { getS3Client, sanitizeFilenameForMetadata } = await import('@/lib/uploads/s3/s3-client')
|
||||
const s3Client = getS3Client()
|
||||
|
||||
let prefix = ''
|
||||
if (uploadType === 'knowledge-base') {
|
||||
prefix = 'kb/'
|
||||
} else if (uploadType === 'chat') {
|
||||
prefix = 'chat/'
|
||||
} else if (uploadType === 'copilot') {
|
||||
prefix = `${userId}/`
|
||||
}
|
||||
|
||||
const baseMetadata: Record<string, string> = {
|
||||
uploadedAt: new Date().toISOString(),
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
baseMetadata.purpose = 'knowledge-base'
|
||||
} else if (uploadType === 'chat') {
|
||||
baseMetadata.purpose = 'chat'
|
||||
} else if (uploadType === 'copilot') {
|
||||
baseMetadata.purpose = 'copilot'
|
||||
baseMetadata.userId = userId || ''
|
||||
}
|
||||
|
||||
const results = await Promise.all(
|
||||
files.map(async (file) => {
|
||||
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
|
||||
const sanitizedOriginalName = sanitizeFilenameForMetadata(file.fileName)
|
||||
|
||||
const metadata = {
|
||||
...baseMetadata,
|
||||
originalName: sanitizedOriginalName,
|
||||
}
|
||||
|
||||
const command = new PutObjectCommand({
|
||||
Bucket: config.bucket,
|
||||
Key: uniqueKey,
|
||||
ContentType: file.contentType,
|
||||
Metadata: metadata,
|
||||
})
|
||||
|
||||
const presignedUrl = await getSignedUrl(s3Client, command, { expiresIn: 3600 })
|
||||
|
||||
const finalPath =
|
||||
uploadType === 'chat'
|
||||
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
|
||||
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
return {
|
||||
fileName: file.fileName,
|
||||
presignedUrl,
|
||||
fileInfo: {
|
||||
path: finalPath,
|
||||
key: uniqueKey,
|
||||
name: file.fileName,
|
||||
size: file.fileSize,
|
||||
type: file.contentType,
|
||||
},
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
files: results,
|
||||
directUploadSupported: true,
|
||||
}
|
||||
}
|
||||
|
||||
async function handleBatchBlobPresignedUrls(
|
||||
files: BatchFileRequest[],
|
||||
uploadType: UploadType,
|
||||
userId?: string
|
||||
) {
|
||||
const config =
|
||||
uploadType === 'knowledge-base'
|
||||
? BLOB_KB_CONFIG
|
||||
: uploadType === 'chat'
|
||||
? BLOB_CHAT_CONFIG
|
||||
: uploadType === 'copilot'
|
||||
? BLOB_COPILOT_CONFIG
|
||||
: BLOB_CONFIG
|
||||
|
||||
if (
|
||||
!config.accountName ||
|
||||
!config.containerName ||
|
||||
(!config.accountKey && !config.connectionString)
|
||||
) {
|
||||
throw new Error(`Azure Blob configuration missing for ${uploadType} uploads`)
|
||||
}
|
||||
|
||||
const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
|
||||
const { BlobSASPermissions, generateBlobSASQueryParameters, StorageSharedKeyCredential } =
|
||||
await import('@azure/storage-blob')
|
||||
|
||||
const blobServiceClient = getBlobServiceClient()
|
||||
const containerClient = blobServiceClient.getContainerClient(config.containerName)
|
||||
|
||||
let prefix = ''
|
||||
if (uploadType === 'knowledge-base') {
|
||||
prefix = 'kb/'
|
||||
} else if (uploadType === 'chat') {
|
||||
prefix = 'chat/'
|
||||
} else if (uploadType === 'copilot') {
|
||||
prefix = `${userId}/`
|
||||
}
|
||||
|
||||
const baseUploadHeaders: Record<string, string> = {
|
||||
'x-ms-blob-type': 'BlockBlob',
|
||||
'x-ms-meta-uploadedat': new Date().toISOString(),
|
||||
}
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'knowledge-base'
|
||||
} else if (uploadType === 'chat') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'chat'
|
||||
} else if (uploadType === 'copilot') {
|
||||
baseUploadHeaders['x-ms-meta-purpose'] = 'copilot'
|
||||
baseUploadHeaders['x-ms-meta-userid'] = encodeURIComponent(userId || '')
|
||||
}
|
||||
|
||||
const results = await Promise.all(
|
||||
files.map(async (file) => {
|
||||
const safeFileName = file.fileName.replace(/\s+/g, '-').replace(/[^a-zA-Z0-9.-]/g, '_')
|
||||
const uniqueKey = `${prefix}${uuidv4()}-${safeFileName}`
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(uniqueKey)
|
||||
|
||||
const sasOptions = {
|
||||
containerName: config.containerName,
|
||||
blobName: uniqueKey,
|
||||
permissions: BlobSASPermissions.parse('w'),
|
||||
startsOn: new Date(),
|
||||
expiresOn: new Date(Date.now() + 3600 * 1000),
|
||||
}
|
||||
|
||||
const sasToken = generateBlobSASQueryParameters(
|
||||
sasOptions,
|
||||
new StorageSharedKeyCredential(config.accountName, config.accountKey || '')
|
||||
).toString()
|
||||
|
||||
const presignedUrl = `${blockBlobClient.url}?${sasToken}`
|
||||
|
||||
const finalPath =
|
||||
uploadType === 'chat'
|
||||
? blockBlobClient.url
|
||||
: `/api/files/serve/blob/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
const uploadHeaders = {
|
||||
...baseUploadHeaders,
|
||||
'x-ms-blob-content-type': file.contentType,
|
||||
'x-ms-meta-originalname': encodeURIComponent(file.fileName),
|
||||
}
|
||||
|
||||
return {
|
||||
fileName: file.fileName,
|
||||
presignedUrl,
|
||||
fileInfo: {
|
||||
path: finalPath,
|
||||
key: uniqueKey,
|
||||
name: file.fileName,
|
||||
size: file.fileSize,
|
||||
type: file.contentType,
|
||||
},
|
||||
uploadHeaders,
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
files: results,
|
||||
directUploadSupported: true,
|
||||
}
|
||||
}
|
||||
|
||||
export async function OPTIONS() {
|
||||
return createOptionsResponse()
|
||||
}
|
||||
@@ -1,7 +1,13 @@
|
||||
import { NextRequest } from 'next/server'
|
||||
import { beforeEach, describe, expect, test, vi } from 'vitest'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { setupFileApiMocks } from '@/app/api/__test-utils__/utils'
|
||||
|
||||
/**
|
||||
* Tests for file presigned API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
|
||||
describe('/api/files/presigned', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -19,7 +25,7 @@ describe('/api/files/presigned', () => {
|
||||
})
|
||||
|
||||
describe('POST', () => {
|
||||
test('should return error when cloud storage is not enabled', async () => {
|
||||
it('should return error when cloud storage is not enabled', async () => {
|
||||
setupFileApiMocks({
|
||||
cloudEnabled: false,
|
||||
storageProvider: 's3',
|
||||
@@ -39,7 +45,7 @@ describe('/api/files/presigned', () => {
|
||||
const response = await POST(request)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(500) // Changed from 400 to 500 (StorageConfigError)
|
||||
expect(response.status).toBe(500)
|
||||
expect(data.error).toBe('Direct uploads are only available when cloud storage is enabled')
|
||||
expect(data.code).toBe('STORAGE_CONFIG_ERROR')
|
||||
expect(data.directUploadSupported).toBe(false)
|
||||
|
||||
@@ -5,6 +5,7 @@ import { v4 as uuidv4 } from 'uuid'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
|
||||
import { isImageFileType } from '@/lib/uploads/file-utils'
|
||||
// Dynamic imports for storage clients to avoid client-side bundling
|
||||
import {
|
||||
BLOB_CHAT_CONFIG,
|
||||
@@ -16,6 +17,7 @@ import {
|
||||
S3_COPILOT_CONFIG,
|
||||
S3_KB_CONFIG,
|
||||
} from '@/lib/uploads/setup'
|
||||
import { validateFileType } from '@/lib/uploads/validation'
|
||||
import { createErrorResponse, createOptionsResponse } from '@/app/api/files/utils'
|
||||
|
||||
const logger = createLogger('PresignedUploadAPI')
|
||||
@@ -96,6 +98,13 @@ export async function POST(request: NextRequest) {
|
||||
? 'copilot'
|
||||
: 'general'
|
||||
|
||||
if (uploadType === 'knowledge-base') {
|
||||
const fileValidationError = validateFileType(fileName, contentType)
|
||||
if (fileValidationError) {
|
||||
throw new ValidationError(`${fileValidationError.message}`)
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate user id from session for copilot uploads
|
||||
const sessionUserId = session.user.id
|
||||
|
||||
@@ -104,6 +113,12 @@ export async function POST(request: NextRequest) {
|
||||
if (!sessionUserId?.trim()) {
|
||||
throw new ValidationError('Authenticated user session is required for copilot uploads')
|
||||
}
|
||||
// Only allow image uploads for copilot
|
||||
if (!isImageFileType(contentType)) {
|
||||
throw new ValidationError(
|
||||
'Only image files (JPEG, PNG, GIF, WebP, SVG) are allowed for copilot uploads'
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (!isUsingCloudStorage()) {
|
||||
@@ -224,10 +239,9 @@ async function handleS3PresignedUrl(
|
||||
)
|
||||
}
|
||||
|
||||
// For chat images, use direct S3 URLs since they need to be permanently accessible
|
||||
// For other files, use serve path for access control
|
||||
// For chat images and knowledge base files, use direct URLs since they need to be accessible by external services
|
||||
const finalPath =
|
||||
uploadType === 'chat'
|
||||
uploadType === 'chat' || uploadType === 'knowledge-base'
|
||||
? `https://${config.bucket}.s3.${config.region}.amazonaws.com/${uniqueKey}`
|
||||
: `/api/files/serve/s3/${encodeURIComponent(uniqueKey)}`
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import { readFile } from 'fs/promises'
|
||||
import type { NextRequest, NextResponse } from 'next/server'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { downloadFile, getStorageProvider, isUsingCloudStorage } from '@/lib/uploads'
|
||||
import { BLOB_KB_CONFIG, S3_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
import { S3_KB_CONFIG } from '@/lib/uploads/setup'
|
||||
import '@/lib/uploads/setup.server'
|
||||
|
||||
import {
|
||||
@@ -15,19 +15,6 @@ import {
|
||||
|
||||
const logger = createLogger('FilesServeAPI')
|
||||
|
||||
async function streamToBuffer(readableStream: NodeJS.ReadableStream): Promise<Buffer> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const chunks: Buffer[] = []
|
||||
readableStream.on('data', (data) => {
|
||||
chunks.push(data instanceof Buffer ? data : Buffer.from(data))
|
||||
})
|
||||
readableStream.on('end', () => {
|
||||
resolve(Buffer.concat(chunks))
|
||||
})
|
||||
readableStream.on('error', reject)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Main API route handler for serving files
|
||||
*/
|
||||
@@ -102,49 +89,23 @@ async function handleLocalFile(filename: string): Promise<NextResponse> {
|
||||
}
|
||||
|
||||
async function downloadKBFile(cloudKey: string): Promise<Buffer> {
|
||||
logger.info(`Downloading KB file: ${cloudKey}`)
|
||||
const storageProvider = getStorageProvider()
|
||||
|
||||
if (storageProvider === 'blob') {
|
||||
logger.info(`Downloading KB file from Azure Blob Storage: ${cloudKey}`)
|
||||
// Use KB-specific blob configuration
|
||||
const { getBlobServiceClient } = await import('@/lib/uploads/blob/blob-client')
|
||||
const blobServiceClient = getBlobServiceClient()
|
||||
const containerClient = blobServiceClient.getContainerClient(BLOB_KB_CONFIG.containerName)
|
||||
const blockBlobClient = containerClient.getBlockBlobClient(cloudKey)
|
||||
|
||||
const downloadBlockBlobResponse = await blockBlobClient.download()
|
||||
if (!downloadBlockBlobResponse.readableStreamBody) {
|
||||
throw new Error('Failed to get readable stream from blob download')
|
||||
}
|
||||
|
||||
// Convert stream to buffer
|
||||
return await streamToBuffer(downloadBlockBlobResponse.readableStreamBody)
|
||||
const { BLOB_KB_CONFIG } = await import('@/lib/uploads/setup')
|
||||
return downloadFile(cloudKey, {
|
||||
containerName: BLOB_KB_CONFIG.containerName,
|
||||
accountName: BLOB_KB_CONFIG.accountName,
|
||||
accountKey: BLOB_KB_CONFIG.accountKey,
|
||||
connectionString: BLOB_KB_CONFIG.connectionString,
|
||||
})
|
||||
}
|
||||
|
||||
if (storageProvider === 's3') {
|
||||
logger.info(`Downloading KB file from S3: ${cloudKey}`)
|
||||
// Use KB-specific S3 configuration
|
||||
const { getS3Client } = await import('@/lib/uploads/s3/s3-client')
|
||||
const { GetObjectCommand } = await import('@aws-sdk/client-s3')
|
||||
|
||||
const s3Client = getS3Client()
|
||||
const command = new GetObjectCommand({
|
||||
Bucket: S3_KB_CONFIG.bucket,
|
||||
Key: cloudKey,
|
||||
})
|
||||
|
||||
const response = await s3Client.send(command)
|
||||
if (!response.Body) {
|
||||
throw new Error('No body in S3 response')
|
||||
}
|
||||
|
||||
// Convert stream to buffer using the same method as the regular S3 client
|
||||
const stream = response.Body as any
|
||||
return new Promise<Buffer>((resolve, reject) => {
|
||||
const chunks: Buffer[] = []
|
||||
stream.on('data', (chunk: Buffer) => chunks.push(chunk))
|
||||
stream.on('end', () => resolve(Buffer.concat(chunks)))
|
||||
stream.on('error', reject)
|
||||
return downloadFile(cloudKey, {
|
||||
bucket: S3_KB_CONFIG.bucket,
|
||||
region: S3_KB_CONFIG.region,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -167,17 +128,22 @@ async function handleCloudProxy(
|
||||
if (isKBFile) {
|
||||
fileBuffer = await downloadKBFile(cloudKey)
|
||||
} else if (bucketType === 'copilot') {
|
||||
// Download from copilot-specific bucket
|
||||
const storageProvider = getStorageProvider()
|
||||
|
||||
if (storageProvider === 's3') {
|
||||
const { downloadFromS3WithConfig } = await import('@/lib/uploads/s3/s3-client')
|
||||
const { S3_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
|
||||
fileBuffer = await downloadFromS3WithConfig(cloudKey, S3_COPILOT_CONFIG)
|
||||
fileBuffer = await downloadFile(cloudKey, {
|
||||
bucket: S3_COPILOT_CONFIG.bucket,
|
||||
region: S3_COPILOT_CONFIG.region,
|
||||
})
|
||||
} else if (storageProvider === 'blob') {
|
||||
// For Azure Blob, use the default downloadFile for now
|
||||
// TODO: Add downloadFromBlobWithConfig when needed
|
||||
fileBuffer = await downloadFile(cloudKey)
|
||||
const { BLOB_COPILOT_CONFIG } = await import('@/lib/uploads/setup')
|
||||
fileBuffer = await downloadFile(cloudKey, {
|
||||
containerName: BLOB_COPILOT_CONFIG.containerName,
|
||||
accountName: BLOB_COPILOT_CONFIG.accountName,
|
||||
accountKey: BLOB_COPILOT_CONFIG.accountKey,
|
||||
connectionString: BLOB_COPILOT_CONFIG.connectionString,
|
||||
})
|
||||
} else {
|
||||
fileBuffer = await downloadFile(cloudKey)
|
||||
}
|
||||
|
||||
@@ -186,3 +186,190 @@ describe('File Upload API Route', () => {
|
||||
expect(response.headers.get('Access-Control-Allow-Headers')).toBe('Content-Type')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File Upload Security Tests', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
vi.clearAllMocks()
|
||||
|
||||
vi.doMock('@/lib/auth', () => ({
|
||||
getSession: vi.fn().mockResolvedValue({
|
||||
user: { id: 'test-user-id' },
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/uploads', () => ({
|
||||
isUsingCloudStorage: vi.fn().mockReturnValue(false),
|
||||
uploadFile: vi.fn().mockResolvedValue({
|
||||
key: 'test-key',
|
||||
path: '/test/path',
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/uploads/setup.server', () => ({}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('File Extension Validation', () => {
|
||||
it('should accept allowed file types', async () => {
|
||||
const allowedTypes = [
|
||||
'pdf',
|
||||
'doc',
|
||||
'docx',
|
||||
'txt',
|
||||
'md',
|
||||
'png',
|
||||
'jpg',
|
||||
'jpeg',
|
||||
'gif',
|
||||
'csv',
|
||||
'xlsx',
|
||||
'xls',
|
||||
]
|
||||
|
||||
for (const ext of allowedTypes) {
|
||||
const formData = new FormData()
|
||||
const file = new File(['test content'], `test.${ext}`, { type: 'application/octet-stream' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
}
|
||||
})
|
||||
|
||||
it('should reject HTML files to prevent XSS', async () => {
|
||||
const formData = new FormData()
|
||||
const maliciousContent = '<script>alert("XSS")</script>'
|
||||
const file = new File([maliciousContent], 'malicious.html', { type: 'text/html' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const data = await response.json()
|
||||
expect(data.message).toContain("File type 'html' is not allowed")
|
||||
})
|
||||
|
||||
it('should reject SVG files to prevent XSS', async () => {
|
||||
const formData = new FormData()
|
||||
const maliciousSvg = '<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
|
||||
const file = new File([maliciousSvg], 'malicious.svg', { type: 'image/svg+xml' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const data = await response.json()
|
||||
expect(data.message).toContain("File type 'svg' is not allowed")
|
||||
})
|
||||
|
||||
it('should reject JavaScript files', async () => {
|
||||
const formData = new FormData()
|
||||
const maliciousJs = 'alert("XSS")'
|
||||
const file = new File([maliciousJs], 'malicious.js', { type: 'application/javascript' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const data = await response.json()
|
||||
expect(data.message).toContain("File type 'js' is not allowed")
|
||||
})
|
||||
|
||||
it('should reject files without extensions', async () => {
|
||||
const formData = new FormData()
|
||||
const file = new File(['test content'], 'noextension', { type: 'application/octet-stream' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const data = await response.json()
|
||||
expect(data.message).toContain("File type 'noextension' is not allowed")
|
||||
})
|
||||
|
||||
it('should handle multiple files with mixed valid/invalid types', async () => {
|
||||
const formData = new FormData()
|
||||
|
||||
// Valid file
|
||||
const validFile = new File(['valid content'], 'valid.pdf', { type: 'application/pdf' })
|
||||
formData.append('file', validFile)
|
||||
|
||||
// Invalid file (should cause rejection of entire request)
|
||||
const invalidFile = new File(['<script>alert("XSS")</script>'], 'malicious.html', {
|
||||
type: 'text/html',
|
||||
})
|
||||
formData.append('file', invalidFile)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
const data = await response.json()
|
||||
expect(data.message).toContain("File type 'html' is not allowed")
|
||||
})
|
||||
})
|
||||
|
||||
describe('Authentication Requirements', () => {
|
||||
it('should reject uploads without authentication', async () => {
|
||||
vi.doMock('@/lib/auth', () => ({
|
||||
getSession: vi.fn().mockResolvedValue(null),
|
||||
}))
|
||||
|
||||
const formData = new FormData()
|
||||
const file = new File(['test content'], 'test.pdf', { type: 'application/pdf' })
|
||||
formData.append('file', file)
|
||||
|
||||
const req = new Request('http://localhost/api/files/upload', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/files/upload/route')
|
||||
const response = await POST(req as any)
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
const data = await response.json()
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,6 +9,34 @@ import {
|
||||
InvalidRequestError,
|
||||
} from '@/app/api/files/utils'
|
||||
|
||||
// Allowlist of permitted file extensions for security
|
||||
const ALLOWED_EXTENSIONS = new Set([
|
||||
// Documents
|
||||
'pdf',
|
||||
'doc',
|
||||
'docx',
|
||||
'txt',
|
||||
'md',
|
||||
// Images (safe formats)
|
||||
'png',
|
||||
'jpg',
|
||||
'jpeg',
|
||||
'gif',
|
||||
// Data files
|
||||
'csv',
|
||||
'xlsx',
|
||||
'xls',
|
||||
])
|
||||
|
||||
/**
|
||||
* Validates file extension against allowlist
|
||||
*/
|
||||
function validateFileExtension(filename: string): boolean {
|
||||
const extension = filename.split('.').pop()?.toLowerCase()
|
||||
if (!extension) return false
|
||||
return ALLOWED_EXTENSIONS.has(extension)
|
||||
}
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
const logger = createLogger('FilesUploadAPI')
|
||||
@@ -49,6 +77,14 @@ export async function POST(request: NextRequest) {
|
||||
// Process each file
|
||||
for (const file of files) {
|
||||
const originalName = file.name
|
||||
|
||||
if (!validateFileExtension(originalName)) {
|
||||
const extension = originalName.split('.').pop()?.toLowerCase() || 'unknown'
|
||||
throw new InvalidRequestError(
|
||||
`File type '${extension}' is not allowed. Allowed types: ${Array.from(ALLOWED_EXTENSIONS).join(', ')}`
|
||||
)
|
||||
}
|
||||
|
||||
const bytes = await file.arrayBuffer()
|
||||
const buffer = Buffer.from(bytes)
|
||||
|
||||
|
||||
327
apps/sim/app/api/files/utils.test.ts
Normal file
327
apps/sim/app/api/files/utils.test.ts
Normal file
@@ -0,0 +1,327 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
import { createFileResponse, extractFilename } from './utils'
|
||||
|
||||
describe('extractFilename', () => {
|
||||
describe('legitimate file paths', () => {
|
||||
it('should extract filename from standard serve path', () => {
|
||||
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
|
||||
})
|
||||
|
||||
it('should extract filename from serve path with special characters', () => {
|
||||
expect(extractFilename('/api/files/serve/document-with-dashes_and_underscores.pdf')).toBe(
|
||||
'document-with-dashes_and_underscores.pdf'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle simple filename without serve path', () => {
|
||||
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
|
||||
})
|
||||
|
||||
it('should extract last segment from nested path', () => {
|
||||
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cloud storage paths', () => {
|
||||
it('should preserve S3 path structure', () => {
|
||||
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
|
||||
's3/1234567890-test-file.txt'
|
||||
)
|
||||
})
|
||||
|
||||
it('should preserve S3 path with nested folders', () => {
|
||||
expect(extractFilename('/api/files/serve/s3/folder/subfolder/document.pdf')).toBe(
|
||||
's3/folder/subfolder/document.pdf'
|
||||
)
|
||||
})
|
||||
|
||||
it('should preserve Azure Blob path structure', () => {
|
||||
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
|
||||
'blob/1234567890-test-document.pdf'
|
||||
)
|
||||
})
|
||||
|
||||
it('should preserve Blob path with nested folders', () => {
|
||||
expect(extractFilename('/api/files/serve/blob/uploads/user-files/report.xlsx')).toBe(
|
||||
'blob/uploads/user-files/report.xlsx'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('security - path traversal prevention', () => {
|
||||
it('should sanitize basic path traversal attempt', () => {
|
||||
expect(extractFilename('/api/files/serve/../config.txt')).toBe('config.txt')
|
||||
})
|
||||
|
||||
it('should sanitize deep path traversal attempt', () => {
|
||||
expect(extractFilename('/api/files/serve/../../../../../etc/passwd')).toBe('etcpasswd')
|
||||
})
|
||||
|
||||
it('should sanitize multiple path traversal patterns', () => {
|
||||
expect(extractFilename('/api/files/serve/../../secret.txt')).toBe('secret.txt')
|
||||
})
|
||||
|
||||
it('should sanitize path traversal with forward slashes', () => {
|
||||
expect(extractFilename('/api/files/serve/../../../system/file')).toBe('systemfile')
|
||||
})
|
||||
|
||||
it('should sanitize mixed path traversal patterns', () => {
|
||||
expect(extractFilename('/api/files/serve/../folder/../file.txt')).toBe('folderfile.txt')
|
||||
})
|
||||
|
||||
it('should remove directory separators from local filenames', () => {
|
||||
expect(extractFilename('/api/files/serve/folder/with/separators.txt')).toBe(
|
||||
'folderwithseparators.txt'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle backslash path separators (Windows style)', () => {
|
||||
expect(extractFilename('/api/files/serve/folder\\file.txt')).toBe('folderfile.txt')
|
||||
})
|
||||
})
|
||||
|
||||
describe('cloud storage path traversal prevention', () => {
|
||||
it('should sanitize S3 path traversal attempts while preserving structure', () => {
|
||||
expect(extractFilename('/api/files/serve/s3/../config')).toBe('s3/config')
|
||||
})
|
||||
|
||||
it('should sanitize S3 path with nested traversal attempts', () => {
|
||||
expect(extractFilename('/api/files/serve/s3/folder/../sensitive/../file.txt')).toBe(
|
||||
's3/folder/sensitive/file.txt'
|
||||
)
|
||||
})
|
||||
|
||||
it('should sanitize Blob path traversal attempts while preserving structure', () => {
|
||||
expect(extractFilename('/api/files/serve/blob/../system.txt')).toBe('blob/system.txt')
|
||||
})
|
||||
|
||||
it('should remove leading dots from cloud path segments', () => {
|
||||
expect(extractFilename('/api/files/serve/s3/.hidden/../file.txt')).toBe('s3/hidden/file.txt')
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases and error handling', () => {
|
||||
it('should handle filename with dots (but not traversal)', () => {
|
||||
expect(extractFilename('/api/files/serve/file.with.dots.txt')).toBe('file.with.dots.txt')
|
||||
})
|
||||
|
||||
it('should handle filename with multiple extensions', () => {
|
||||
expect(extractFilename('/api/files/serve/archive.tar.gz')).toBe('archive.tar.gz')
|
||||
})
|
||||
|
||||
it('should throw error for empty filename after sanitization', () => {
|
||||
expect(() => extractFilename('/api/files/serve/')).toThrow(
|
||||
'Invalid or empty filename after sanitization'
|
||||
)
|
||||
})
|
||||
|
||||
it('should throw error for filename that becomes empty after path traversal removal', () => {
|
||||
expect(() => extractFilename('/api/files/serve/../..')).toThrow(
|
||||
'Invalid or empty filename after sanitization'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle single character filenames', () => {
|
||||
expect(extractFilename('/api/files/serve/a')).toBe('a')
|
||||
})
|
||||
|
||||
it('should handle numeric filenames', () => {
|
||||
expect(extractFilename('/api/files/serve/123')).toBe('123')
|
||||
})
|
||||
})
|
||||
|
||||
describe('backward compatibility', () => {
|
||||
it('should match old behavior for legitimate local files', () => {
|
||||
// These test cases verify that our security fix maintains exact backward compatibility
|
||||
// for all legitimate use cases found in the existing codebase
|
||||
expect(extractFilename('/api/files/serve/test-file.txt')).toBe('test-file.txt')
|
||||
expect(extractFilename('/api/files/serve/nonexistent.txt')).toBe('nonexistent.txt')
|
||||
})
|
||||
|
||||
it('should match old behavior for legitimate cloud files', () => {
|
||||
// These test cases are from the actual delete route tests
|
||||
expect(extractFilename('/api/files/serve/s3/1234567890-test-file.txt')).toBe(
|
||||
's3/1234567890-test-file.txt'
|
||||
)
|
||||
expect(extractFilename('/api/files/serve/blob/1234567890-test-document.pdf')).toBe(
|
||||
'blob/1234567890-test-document.pdf'
|
||||
)
|
||||
})
|
||||
|
||||
it('should match old behavior for simple paths', () => {
|
||||
// These match the mock implementations in serve route tests
|
||||
expect(extractFilename('simple-file.txt')).toBe('simple-file.txt')
|
||||
expect(extractFilename('nested/path/file.txt')).toBe('file.txt')
|
||||
})
|
||||
})
|
||||
|
||||
describe('File Serving Security Tests', () => {
|
||||
describe('createFileResponse security headers', () => {
|
||||
it('should serve safe images inline with proper headers', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('fake-image-data'),
|
||||
contentType: 'image/png',
|
||||
filename: 'safe-image.png',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('image/png')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'inline; filename="safe-image.png"'
|
||||
)
|
||||
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
|
||||
expect(response.headers.get('Content-Security-Policy')).toBe(
|
||||
"default-src 'none'; style-src 'unsafe-inline'; sandbox;"
|
||||
)
|
||||
})
|
||||
|
||||
it('should serve PDFs inline safely', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('fake-pdf-data'),
|
||||
contentType: 'application/pdf',
|
||||
filename: 'document.pdf',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/pdf')
|
||||
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.pdf"')
|
||||
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
|
||||
})
|
||||
|
||||
it('should force attachment for HTML files to prevent XSS', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('<script>alert("XSS")</script>'),
|
||||
contentType: 'text/html',
|
||||
filename: 'malicious.html',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="malicious.html"'
|
||||
)
|
||||
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
|
||||
})
|
||||
|
||||
it('should force attachment for SVG files to prevent XSS', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from(
|
||||
'<svg onload="alert(\'XSS\')" xmlns="http://www.w3.org/2000/svg"></svg>'
|
||||
),
|
||||
contentType: 'image/svg+xml',
|
||||
filename: 'malicious.svg',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="malicious.svg"'
|
||||
)
|
||||
})
|
||||
|
||||
it('should override dangerous content types to safe alternatives', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('<svg>safe content</svg>'),
|
||||
contentType: 'image/svg+xml',
|
||||
filename: 'image.png', // Extension doesn't match content-type
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
// Should override SVG content type to plain text for safety
|
||||
expect(response.headers.get('Content-Type')).toBe('text/plain')
|
||||
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="image.png"')
|
||||
})
|
||||
|
||||
it('should force attachment for JavaScript files', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('alert("XSS")'),
|
||||
contentType: 'application/javascript',
|
||||
filename: 'malicious.js',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="malicious.js"'
|
||||
)
|
||||
})
|
||||
|
||||
it('should force attachment for CSS files', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('body { background: url(javascript:alert("XSS")) }'),
|
||||
contentType: 'text/css',
|
||||
filename: 'malicious.css',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="malicious.css"'
|
||||
)
|
||||
})
|
||||
|
||||
it('should force attachment for XML files', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('<?xml version="1.0"?><root><script>alert("XSS")</script></root>'),
|
||||
contentType: 'application/xml',
|
||||
filename: 'malicious.xml',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/octet-stream')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="malicious.xml"'
|
||||
)
|
||||
})
|
||||
|
||||
it('should serve text files safely', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('Safe text content'),
|
||||
contentType: 'text/plain',
|
||||
filename: 'document.txt',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('text/plain')
|
||||
expect(response.headers.get('Content-Disposition')).toBe('inline; filename="document.txt"')
|
||||
})
|
||||
|
||||
it('should force attachment for unknown/unsafe content types', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('unknown content'),
|
||||
contentType: 'application/unknown',
|
||||
filename: 'unknown.bin',
|
||||
})
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(response.headers.get('Content-Type')).toBe('application/unknown')
|
||||
expect(response.headers.get('Content-Disposition')).toBe(
|
||||
'attachment; filename="unknown.bin"'
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Content Security Policy', () => {
|
||||
it('should include CSP header in all responses', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('test'),
|
||||
contentType: 'text/plain',
|
||||
filename: 'test.txt',
|
||||
})
|
||||
|
||||
const csp = response.headers.get('Content-Security-Policy')
|
||||
expect(csp).toBe("default-src 'none'; style-src 'unsafe-inline'; sandbox;")
|
||||
})
|
||||
|
||||
it('should include X-Content-Type-Options header', () => {
|
||||
const response = createFileResponse({
|
||||
buffer: Buffer.from('test'),
|
||||
contentType: 'text/plain',
|
||||
filename: 'test.txt',
|
||||
})
|
||||
|
||||
expect(response.headers.get('X-Content-Type-Options')).toBe('nosniff')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -70,7 +70,6 @@ export const contentTypeMap: Record<string, string> = {
|
||||
jpg: 'image/jpeg',
|
||||
jpeg: 'image/jpeg',
|
||||
gif: 'image/gif',
|
||||
svg: 'image/svg+xml',
|
||||
// Archive formats
|
||||
zip: 'application/zip',
|
||||
// Folder format
|
||||
@@ -153,10 +152,43 @@ export function extractBlobKey(path: string): string {
|
||||
* Extract filename from a serve path
|
||||
*/
|
||||
export function extractFilename(path: string): string {
|
||||
let filename: string
|
||||
|
||||
if (path.startsWith('/api/files/serve/')) {
|
||||
return path.substring('/api/files/serve/'.length)
|
||||
filename = path.substring('/api/files/serve/'.length)
|
||||
} else {
|
||||
filename = path.split('/').pop() || path
|
||||
}
|
||||
return path.split('/').pop() || path
|
||||
|
||||
filename = filename
|
||||
.replace(/\.\./g, '')
|
||||
.replace(/\/\.\./g, '')
|
||||
.replace(/\.\.\//g, '')
|
||||
|
||||
// Handle cloud storage paths (s3/key, blob/key) - preserve forward slashes for these
|
||||
if (filename.startsWith('s3/') || filename.startsWith('blob/')) {
|
||||
// For cloud paths, only sanitize the key portion after the prefix
|
||||
const parts = filename.split('/')
|
||||
const prefix = parts[0] // 's3' or 'blob'
|
||||
const keyParts = parts.slice(1)
|
||||
|
||||
// Sanitize each part of the key to prevent traversal
|
||||
const sanitizedKeyParts = keyParts
|
||||
.map((part) => part.replace(/\.\./g, '').replace(/^\./g, '').trim())
|
||||
.filter((part) => part.length > 0)
|
||||
|
||||
filename = `${prefix}/${sanitizedKeyParts.join('/')}`
|
||||
} else {
|
||||
// For regular filenames, remove any remaining path separators
|
||||
filename = filename.replace(/[/\\]/g, '')
|
||||
}
|
||||
|
||||
// Additional validation: ensure filename is not empty after sanitization
|
||||
if (!filename || filename.trim().length === 0) {
|
||||
throw new Error('Invalid or empty filename after sanitization')
|
||||
}
|
||||
|
||||
return filename
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -174,16 +206,65 @@ export function findLocalFile(filename: string): string | null {
|
||||
return null
|
||||
}
|
||||
|
||||
const SAFE_INLINE_TYPES = new Set([
|
||||
'image/png',
|
||||
'image/jpeg',
|
||||
'image/jpg',
|
||||
'image/gif',
|
||||
'application/pdf',
|
||||
'text/plain',
|
||||
'text/csv',
|
||||
'application/json',
|
||||
])
|
||||
|
||||
// File extensions that should always be served as attachment for security
|
||||
const FORCE_ATTACHMENT_EXTENSIONS = new Set(['html', 'htm', 'svg', 'js', 'css', 'xml'])
|
||||
|
||||
/**
|
||||
* Create a file response with appropriate headers
|
||||
* Determines safe content type and disposition for file serving
|
||||
*/
|
||||
function getSecureFileHeaders(filename: string, originalContentType: string) {
|
||||
const extension = filename.split('.').pop()?.toLowerCase() || ''
|
||||
|
||||
// Force attachment for potentially dangerous file types
|
||||
if (FORCE_ATTACHMENT_EXTENSIONS.has(extension)) {
|
||||
return {
|
||||
contentType: 'application/octet-stream', // Force download
|
||||
disposition: 'attachment',
|
||||
}
|
||||
}
|
||||
|
||||
// Override content type for safety while preserving legitimate use cases
|
||||
let safeContentType = originalContentType
|
||||
|
||||
// Handle potentially dangerous content types
|
||||
if (originalContentType === 'text/html' || originalContentType === 'image/svg+xml') {
|
||||
safeContentType = 'text/plain' // Prevent browser rendering
|
||||
}
|
||||
|
||||
// Use inline only for verified safe content types
|
||||
const disposition = SAFE_INLINE_TYPES.has(safeContentType) ? 'inline' : 'attachment'
|
||||
|
||||
return {
|
||||
contentType: safeContentType,
|
||||
disposition,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a file response with appropriate security headers
|
||||
*/
|
||||
export function createFileResponse(file: FileResponse): NextResponse {
|
||||
return new NextResponse(file.buffer, {
|
||||
const { contentType, disposition } = getSecureFileHeaders(file.filename, file.contentType)
|
||||
|
||||
return new NextResponse(file.buffer as BodyInit, {
|
||||
status: 200,
|
||||
headers: {
|
||||
'Content-Type': file.contentType,
|
||||
'Content-Disposition': `inline; filename="${file.filename}"`,
|
||||
'Content-Type': contentType,
|
||||
'Content-Disposition': `${disposition}; filename="${file.filename}"`,
|
||||
'Cache-Control': 'public, max-age=31536000', // Cache for 1 year
|
||||
'X-Content-Type-Options': 'nosniff',
|
||||
'Content-Security-Policy': "default-src 'none'; style-src 'unsafe-inline'; sandbox;",
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,6 @@ import { NextRequest } from 'next/server'
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { createMockRequest } from '@/app/api/__test-utils__/utils'
|
||||
|
||||
const mockFreestyleExecuteScript = vi.fn()
|
||||
const mockCreateContext = vi.fn()
|
||||
const mockRunInContext = vi.fn()
|
||||
const mockLogger = {
|
||||
@@ -29,26 +28,17 @@ describe('Function Execute API Route', () => {
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.doMock('freestyle-sandboxes', () => ({
|
||||
FreestyleSandboxes: vi.fn().mockImplementation(() => ({
|
||||
executeScript: mockFreestyleExecuteScript,
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
FREESTYLE_API_KEY: 'test-freestyle-key',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/logs/console/logger', () => ({
|
||||
createLogger: vi.fn().mockReturnValue(mockLogger),
|
||||
}))
|
||||
|
||||
mockFreestyleExecuteScript.mockResolvedValue({
|
||||
result: 'freestyle success',
|
||||
logs: [],
|
||||
})
|
||||
vi.doMock('@/lib/execution/e2b', () => ({
|
||||
executeInE2B: vi.fn().mockResolvedValue({
|
||||
result: 'e2b success',
|
||||
stdout: 'e2b output',
|
||||
sandboxId: 'test-sandbox-id',
|
||||
}),
|
||||
}))
|
||||
|
||||
mockRunInContext.mockResolvedValue('vm success')
|
||||
mockCreateContext.mockReturnValue({})
|
||||
@@ -63,6 +53,7 @@ describe('Function Execute API Route', () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "Hello World"',
|
||||
timeout: 5000,
|
||||
useLocalVM: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
@@ -92,6 +83,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should use default timeout when not provided', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "test"',
|
||||
useLocalVM: true,
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
@@ -111,6 +103,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should resolve environment variables with {{var_name}} syntax', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return {{API_KEY}}',
|
||||
useLocalVM: true,
|
||||
envVars: {
|
||||
API_KEY: 'secret-key-123',
|
||||
},
|
||||
@@ -126,6 +119,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should resolve tag variables with <tag_name> syntax', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return <email>',
|
||||
useLocalVM: true,
|
||||
params: {
|
||||
email: { id: '123', subject: 'Test Email' },
|
||||
},
|
||||
@@ -141,6 +135,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should NOT treat email addresses as template variables', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "Email sent to user"',
|
||||
useLocalVM: true,
|
||||
params: {
|
||||
email: {
|
||||
from: 'Waleed Latif <waleed@sim.ai>',
|
||||
@@ -159,6 +154,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should only match valid variable names in angle brackets', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return <validVar> + "<invalid@email.com>" + <another_valid>',
|
||||
useLocalVM: true,
|
||||
params: {
|
||||
validVar: 'hello',
|
||||
another_valid: 'world',
|
||||
@@ -196,6 +192,7 @@ describe('Function Execute API Route', () => {
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return <email>',
|
||||
useLocalVM: true,
|
||||
params: gmailData,
|
||||
})
|
||||
|
||||
@@ -218,6 +215,7 @@ describe('Function Execute API Route', () => {
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return <email>',
|
||||
useLocalVM: true,
|
||||
params: complexEmailData,
|
||||
})
|
||||
|
||||
@@ -228,111 +226,11 @@ describe('Function Execute API Route', () => {
|
||||
})
|
||||
})
|
||||
|
||||
describe.skip('Freestyle Execution', () => {
|
||||
it('should use Freestyle when API key is available', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "freestyle test"',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
await POST(req)
|
||||
|
||||
expect(mockFreestyleExecuteScript).toHaveBeenCalled()
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\[.*\] Using Freestyle for code execution/)
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Freestyle errors and fallback to VM', async () => {
|
||||
mockFreestyleExecuteScript.mockRejectedValueOnce(new Error('Freestyle API error'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "fallback test"',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(mockFreestyleExecuteScript).toHaveBeenCalled()
|
||||
expect(mockRunInContext).toHaveBeenCalled()
|
||||
expect(mockLogger.error).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/\[.*\] Freestyle API call failed, falling back to VM:/),
|
||||
expect.any(Object)
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Freestyle script errors', async () => {
|
||||
mockFreestyleExecuteScript.mockResolvedValueOnce({
|
||||
result: null,
|
||||
logs: [{ type: 'error', message: 'ReferenceError: undefined variable' }],
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return undefinedVariable',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const data = await response.json()
|
||||
expect(data.success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('VM Execution', () => {
|
||||
it.skip('should use VM when Freestyle API key is not available', async () => {
|
||||
// Mock no Freestyle API key
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
FREESTYLE_API_KEY: undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "vm test"',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
await POST(req)
|
||||
|
||||
expect(mockFreestyleExecuteScript).not.toHaveBeenCalled()
|
||||
expect(mockRunInContext).toHaveBeenCalled()
|
||||
expect(mockLogger.info).toHaveBeenCalledWith(
|
||||
expect.stringMatching(
|
||||
/\[.*\] Using VM for code execution \(no Freestyle API key available\)/
|
||||
)
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle VM execution errors', async () => {
|
||||
// Mock no Freestyle API key so it uses VM
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
FREESTYLE_API_KEY: undefined,
|
||||
},
|
||||
}))
|
||||
|
||||
mockRunInContext.mockRejectedValueOnce(new Error('VM execution error'))
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return invalidCode(',
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
|
||||
expect(response.status).toBe(500)
|
||||
const data = await response.json()
|
||||
expect(data.success).toBe(false)
|
||||
expect(data.error).toContain('VM execution error')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Custom Tools', () => {
|
||||
it('should handle custom tool execution with direct parameter access', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return location + " weather is sunny"',
|
||||
useLocalVM: true,
|
||||
params: {
|
||||
location: 'San Francisco',
|
||||
},
|
||||
@@ -364,6 +262,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should handle timeout parameter', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "test"',
|
||||
useLocalVM: true,
|
||||
timeout: 10000,
|
||||
})
|
||||
|
||||
@@ -381,6 +280,7 @@ describe('Function Execute API Route', () => {
|
||||
it('should handle empty parameters object', async () => {
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "no params"',
|
||||
useLocalVM: true,
|
||||
params: {},
|
||||
})
|
||||
|
||||
@@ -414,6 +314,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'const obj = {\n name: "test",\n description: "This has a missing closing quote\n};\nreturn obj;',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -457,6 +358,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'const obj = null;\nreturn obj.someMethod();',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -498,6 +400,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'const x = 42;\nreturn undefinedVariable + x;',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -528,6 +431,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return "test";',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -564,6 +468,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'const a = 1;\nconst b = 2;\nconst c = 3;\nconst d = 4;\nreturn a + b + c + d;',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -595,6 +500,7 @@ SyntaxError: Invalid or unexpected token
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'const obj = {\n name: "test"\n// Missing closing brace',
|
||||
useLocalVM: true,
|
||||
timeout: 5000,
|
||||
})
|
||||
|
||||
@@ -615,6 +521,7 @@ SyntaxError: Invalid or unexpected token
|
||||
// This tests the escapeRegExp function indirectly
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return {{special.chars+*?}}',
|
||||
useLocalVM: true,
|
||||
envVars: {
|
||||
'special.chars+*?': 'escaped-value',
|
||||
},
|
||||
@@ -631,6 +538,7 @@ SyntaxError: Invalid or unexpected token
|
||||
// Test with complex but not circular data first
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return <complexData>',
|
||||
useLocalVM: true,
|
||||
params: {
|
||||
complexData: {
|
||||
special: 'chars"with\'quotes',
|
||||
@@ -651,113 +559,3 @@ SyntaxError: Invalid or unexpected token
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Function Execute API - Template Variable Edge Cases', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetModules()
|
||||
vi.resetAllMocks()
|
||||
|
||||
vi.doMock('@/lib/logs/console/logger', () => ({
|
||||
createLogger: vi.fn().mockReturnValue(mockLogger),
|
||||
}))
|
||||
|
||||
vi.doMock('@/lib/env', () => ({
|
||||
env: {
|
||||
FREESTYLE_API_KEY: 'test-freestyle-key',
|
||||
},
|
||||
}))
|
||||
|
||||
vi.doMock('vm', () => ({
|
||||
createContext: mockCreateContext,
|
||||
Script: vi.fn().mockImplementation(() => ({
|
||||
runInContext: mockRunInContext,
|
||||
})),
|
||||
}))
|
||||
|
||||
vi.doMock('freestyle-sandboxes', () => ({
|
||||
FreestyleSandboxes: vi.fn().mockImplementation(() => ({
|
||||
executeScript: mockFreestyleExecuteScript,
|
||||
})),
|
||||
}))
|
||||
|
||||
mockFreestyleExecuteScript.mockResolvedValue({
|
||||
result: 'freestyle success',
|
||||
logs: [],
|
||||
})
|
||||
|
||||
mockRunInContext.mockResolvedValue('vm success')
|
||||
mockCreateContext.mockReturnValue({})
|
||||
})
|
||||
|
||||
it.skip('should handle nested template variables', async () => {
|
||||
mockFreestyleExecuteScript.mockResolvedValueOnce({
|
||||
result: 'environment-valueparam-value',
|
||||
logs: [],
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return {{outer}} + <inner>',
|
||||
envVars: {
|
||||
outer: 'environment-value',
|
||||
},
|
||||
params: {
|
||||
inner: 'param-value',
|
||||
},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.output.result).toBe('environment-valueparam-value')
|
||||
})
|
||||
|
||||
it.skip('should prioritize environment variables over params for {{}} syntax', async () => {
|
||||
mockFreestyleExecuteScript.mockResolvedValueOnce({
|
||||
result: 'env-wins',
|
||||
logs: [],
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return {{conflictVar}}',
|
||||
envVars: {
|
||||
conflictVar: 'env-wins',
|
||||
},
|
||||
params: {
|
||||
conflictVar: 'param-loses',
|
||||
},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
// Environment variable should take precedence
|
||||
expect(data.output.result).toBe('env-wins')
|
||||
})
|
||||
|
||||
it.skip('should handle missing template variables gracefully', async () => {
|
||||
mockFreestyleExecuteScript.mockResolvedValueOnce({
|
||||
result: '',
|
||||
logs: [],
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', {
|
||||
code: 'return {{nonexistent}} + <alsoMissing>',
|
||||
envVars: {},
|
||||
params: {},
|
||||
})
|
||||
|
||||
const { POST } = await import('@/app/api/function/execute/route')
|
||||
const response = await POST(req)
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.output.result).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import { createContext, Script } from 'vm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { env, isTruthy } from '@/lib/env'
|
||||
import { executeInE2B } from '@/lib/execution/e2b'
|
||||
import { CodeLanguage, DEFAULT_CODE_LANGUAGE, isValidCodeLanguage } from '@/lib/execution/languages'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
@@ -8,6 +11,10 @@ export const maxDuration = 60
|
||||
|
||||
const logger = createLogger('FunctionExecuteAPI')
|
||||
|
||||
// Constants for E2B code wrapping line counts
|
||||
const E2B_JS_WRAPPER_LINES = 3 // Lines before user code: ';(async () => {', ' try {', ' const __sim_result = await (async () => {'
|
||||
const E2B_PYTHON_WRAPPER_LINES = 1 // Lines before user code: 'def __sim_main__():'
|
||||
|
||||
/**
|
||||
* Enhanced error information interface
|
||||
*/
|
||||
@@ -124,6 +131,103 @@ function extractEnhancedError(
|
||||
return enhanced
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse and format E2B error message
|
||||
* Removes E2B-specific line references and adds correct user line numbers
|
||||
*/
|
||||
function formatE2BError(
|
||||
errorMessage: string,
|
||||
errorOutput: string,
|
||||
language: CodeLanguage,
|
||||
userCode: string,
|
||||
prologueLineCount: number
|
||||
): { formattedError: string; cleanedOutput: string } {
|
||||
// Calculate line offset based on language and prologue
|
||||
const wrapperLines =
|
||||
language === CodeLanguage.Python ? E2B_PYTHON_WRAPPER_LINES : E2B_JS_WRAPPER_LINES
|
||||
const totalOffset = prologueLineCount + wrapperLines
|
||||
|
||||
let userLine: number | undefined
|
||||
let cleanErrorType = ''
|
||||
let cleanErrorMsg = ''
|
||||
|
||||
if (language === CodeLanguage.Python) {
|
||||
// Python error format: "Cell In[X], line Y" followed by error details
|
||||
// Extract line number from the Cell reference
|
||||
const cellMatch = errorOutput.match(/Cell In\[\d+\], line (\d+)/)
|
||||
if (cellMatch) {
|
||||
const originalLine = Number.parseInt(cellMatch[1], 10)
|
||||
userLine = originalLine - totalOffset
|
||||
}
|
||||
|
||||
// Extract clean error message from the error string
|
||||
// Remove file references like "(detected at line X) (file.py, line Y)"
|
||||
cleanErrorMsg = errorMessage
|
||||
.replace(/\s*\(detected at line \d+\)/g, '')
|
||||
.replace(/\s*\([^)]+\.py, line \d+\)/g, '')
|
||||
.trim()
|
||||
} else if (language === CodeLanguage.JavaScript) {
|
||||
// JavaScript error format from E2B: "SyntaxError: /path/file.ts: Message. (line:col)\n\n 9 | ..."
|
||||
// First, extract the error type and message from the first line
|
||||
const firstLineEnd = errorMessage.indexOf('\n')
|
||||
const firstLine = firstLineEnd > 0 ? errorMessage.substring(0, firstLineEnd) : errorMessage
|
||||
|
||||
// Parse: "SyntaxError: /home/user/index.ts: Missing semicolon. (11:9)"
|
||||
const jsErrorMatch = firstLine.match(/^(\w+Error):\s*[^:]+:\s*([^(]+)\.\s*\((\d+):(\d+)\)/)
|
||||
if (jsErrorMatch) {
|
||||
cleanErrorType = jsErrorMatch[1]
|
||||
cleanErrorMsg = jsErrorMatch[2].trim()
|
||||
const originalLine = Number.parseInt(jsErrorMatch[3], 10)
|
||||
userLine = originalLine - totalOffset
|
||||
} else {
|
||||
// Fallback: look for line number in the arrow pointer line (> 11 |)
|
||||
const arrowMatch = errorMessage.match(/^>\s*(\d+)\s*\|/m)
|
||||
if (arrowMatch) {
|
||||
const originalLine = Number.parseInt(arrowMatch[1], 10)
|
||||
userLine = originalLine - totalOffset
|
||||
}
|
||||
// Try to extract error type and message
|
||||
const errorMatch = firstLine.match(/^(\w+Error):\s*(.+)/)
|
||||
if (errorMatch) {
|
||||
cleanErrorType = errorMatch[1]
|
||||
cleanErrorMsg = errorMatch[2]
|
||||
.replace(/^[^:]+:\s*/, '') // Remove file path
|
||||
.replace(/\s*\(\d+:\d+\)\s*$/, '') // Remove line:col at end
|
||||
.trim()
|
||||
} else {
|
||||
cleanErrorMsg = firstLine
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build the final clean error message
|
||||
const finalErrorMsg =
|
||||
cleanErrorType && cleanErrorMsg
|
||||
? `${cleanErrorType}: ${cleanErrorMsg}`
|
||||
: cleanErrorMsg || errorMessage
|
||||
|
||||
// Format with line number if available
|
||||
let formattedError = finalErrorMsg
|
||||
if (userLine && userLine > 0) {
|
||||
const codeLines = userCode.split('\n')
|
||||
// Clamp userLine to the actual user code range
|
||||
const actualUserLine = Math.min(userLine, codeLines.length)
|
||||
if (actualUserLine > 0 && actualUserLine <= codeLines.length) {
|
||||
const lineContent = codeLines[actualUserLine - 1]?.trim()
|
||||
if (lineContent) {
|
||||
formattedError = `Line ${actualUserLine}: \`${lineContent}\` - ${finalErrorMsg}`
|
||||
} else {
|
||||
formattedError = `Line ${actualUserLine} - ${finalErrorMsg}`
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For stdout, just return the clean error message without the full traceback
|
||||
const cleanedOutput = finalErrorMsg
|
||||
|
||||
return { formattedError, cleanedOutput }
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a detailed error message for users
|
||||
*/
|
||||
@@ -213,24 +317,81 @@ function createUserFriendlyErrorMessage(
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves environment variables and tags in code
|
||||
* @param code - Code with variables
|
||||
* @param params - Parameters that may contain variable values
|
||||
* @param envVars - Environment variables from the workflow
|
||||
* @returns Resolved code
|
||||
* Resolves workflow variables with <variable.name> syntax
|
||||
*/
|
||||
function resolveWorkflowVariables(
|
||||
code: string,
|
||||
workflowVariables: Record<string, any>,
|
||||
contextVariables: Record<string, any>
|
||||
): string {
|
||||
let resolvedCode = code
|
||||
|
||||
function resolveCodeVariables(
|
||||
const variableMatches = resolvedCode.match(/<variable\.([^>]+)>/g) || []
|
||||
for (const match of variableMatches) {
|
||||
const variableName = match.slice('<variable.'.length, -1).trim()
|
||||
|
||||
// Find the variable by name (workflowVariables is indexed by ID, values are variable objects)
|
||||
const foundVariable = Object.entries(workflowVariables).find(
|
||||
([_, variable]) => (variable.name || '').replace(/\s+/g, '') === variableName
|
||||
)
|
||||
|
||||
if (foundVariable) {
|
||||
const variable = foundVariable[1]
|
||||
// Get the typed value - handle different variable types
|
||||
let variableValue = variable.value
|
||||
|
||||
if (variable.value !== undefined && variable.value !== null) {
|
||||
try {
|
||||
// Handle 'string' type the same as 'plain' for backward compatibility
|
||||
const type = variable.type === 'string' ? 'plain' : variable.type
|
||||
|
||||
// For plain text, use exactly what's entered without modifications
|
||||
if (type === 'plain' && typeof variableValue === 'string') {
|
||||
// Use as-is for plain text
|
||||
} else if (type === 'number') {
|
||||
variableValue = Number(variableValue)
|
||||
} else if (type === 'boolean') {
|
||||
variableValue = variableValue === 'true' || variableValue === true
|
||||
} else if (type === 'json') {
|
||||
try {
|
||||
variableValue =
|
||||
typeof variableValue === 'string' ? JSON.parse(variableValue) : variableValue
|
||||
} catch {
|
||||
// Keep original value if JSON parsing fails
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Fallback to original value on error
|
||||
variableValue = variable.value
|
||||
}
|
||||
}
|
||||
|
||||
// Create a safe variable reference
|
||||
const safeVarName = `__variable_${variableName.replace(/[^a-zA-Z0-9_]/g, '_')}`
|
||||
contextVariables[safeVarName] = variableValue
|
||||
|
||||
// Replace the variable reference with the safe variable name
|
||||
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
|
||||
} else {
|
||||
// Variable not found - replace with empty string to avoid syntax errors
|
||||
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), '')
|
||||
}
|
||||
}
|
||||
|
||||
return resolvedCode
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves environment variables with {{var_name}} syntax
|
||||
*/
|
||||
function resolveEnvironmentVariables(
|
||||
code: string,
|
||||
params: Record<string, any>,
|
||||
envVars: Record<string, string> = {},
|
||||
blockData: Record<string, any> = {},
|
||||
blockNameMapping: Record<string, string> = {}
|
||||
): { resolvedCode: string; contextVariables: Record<string, any> } {
|
||||
envVars: Record<string, string>,
|
||||
contextVariables: Record<string, any>
|
||||
): string {
|
||||
let resolvedCode = code
|
||||
const contextVariables: Record<string, any> = {}
|
||||
|
||||
// Resolve environment variables with {{var_name}} syntax
|
||||
const envVarMatches = resolvedCode.match(/\{\{([^}]+)\}\}/g) || []
|
||||
for (const match of envVarMatches) {
|
||||
const varName = match.slice(2, -2).trim()
|
||||
@@ -245,7 +406,21 @@ function resolveCodeVariables(
|
||||
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
|
||||
}
|
||||
|
||||
// Resolve tags with <tag_name> syntax (including nested paths like <block.response.data>)
|
||||
return resolvedCode
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves tags with <tag_name> syntax (including nested paths like <block.response.data>)
|
||||
*/
|
||||
function resolveTagVariables(
|
||||
code: string,
|
||||
params: Record<string, any>,
|
||||
blockData: Record<string, any>,
|
||||
blockNameMapping: Record<string, string>,
|
||||
contextVariables: Record<string, any>
|
||||
): string {
|
||||
let resolvedCode = code
|
||||
|
||||
const tagMatches = resolvedCode.match(/<([a-zA-Z_][a-zA-Z0-9_.]*[a-zA-Z0-9_])>/g) || []
|
||||
|
||||
for (const match of tagMatches) {
|
||||
@@ -300,6 +475,42 @@ function resolveCodeVariables(
|
||||
resolvedCode = resolvedCode.replace(new RegExp(escapeRegExp(match), 'g'), safeVarName)
|
||||
}
|
||||
|
||||
return resolvedCode
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves environment variables and tags in code
|
||||
* @param code - Code with variables
|
||||
* @param params - Parameters that may contain variable values
|
||||
* @param envVars - Environment variables from the workflow
|
||||
* @returns Resolved code
|
||||
*/
|
||||
function resolveCodeVariables(
|
||||
code: string,
|
||||
params: Record<string, any>,
|
||||
envVars: Record<string, string> = {},
|
||||
blockData: Record<string, any> = {},
|
||||
blockNameMapping: Record<string, string> = {},
|
||||
workflowVariables: Record<string, any> = {}
|
||||
): { resolvedCode: string; contextVariables: Record<string, any> } {
|
||||
let resolvedCode = code
|
||||
const contextVariables: Record<string, any> = {}
|
||||
|
||||
// Resolve workflow variables with <variable.name> syntax first
|
||||
resolvedCode = resolveWorkflowVariables(resolvedCode, workflowVariables, contextVariables)
|
||||
|
||||
// Resolve environment variables with {{var_name}} syntax
|
||||
resolvedCode = resolveEnvironmentVariables(resolvedCode, params, envVars, contextVariables)
|
||||
|
||||
// Resolve tags with <tag_name> syntax (including nested paths like <block.response.data>)
|
||||
resolvedCode = resolveTagVariables(
|
||||
resolvedCode,
|
||||
params,
|
||||
blockData,
|
||||
blockNameMapping,
|
||||
contextVariables
|
||||
)
|
||||
|
||||
return { resolvedCode, contextVariables }
|
||||
}
|
||||
|
||||
@@ -335,9 +546,12 @@ export async function POST(req: NextRequest) {
|
||||
code,
|
||||
params = {},
|
||||
timeout = 5000,
|
||||
language = DEFAULT_CODE_LANGUAGE,
|
||||
useLocalVM = false,
|
||||
envVars = {},
|
||||
blockData = {},
|
||||
blockNameMapping = {},
|
||||
workflowVariables = {},
|
||||
workflowId,
|
||||
isCustomTool = false,
|
||||
} = body
|
||||
@@ -360,168 +574,170 @@ export async function POST(req: NextRequest) {
|
||||
executionParams,
|
||||
envVars,
|
||||
blockData,
|
||||
blockNameMapping
|
||||
blockNameMapping,
|
||||
workflowVariables
|
||||
)
|
||||
resolvedCode = codeResolution.resolvedCode
|
||||
const contextVariables = codeResolution.contextVariables
|
||||
|
||||
const executionMethod = 'vm' // Default execution method
|
||||
const e2bEnabled = isTruthy(env.E2B_ENABLED)
|
||||
const lang = isValidCodeLanguage(language) ? language : DEFAULT_CODE_LANGUAGE
|
||||
const useE2B =
|
||||
e2bEnabled &&
|
||||
!useLocalVM &&
|
||||
!isCustomTool &&
|
||||
(lang === CodeLanguage.JavaScript || lang === CodeLanguage.Python)
|
||||
|
||||
// // Try to use Freestyle if the API key is available
|
||||
// if (env.FREESTYLE_API_KEY) {
|
||||
// try {
|
||||
// logger.info(`[${requestId}] Using Freestyle for code execution`)
|
||||
// executionMethod = 'freestyle'
|
||||
if (useE2B) {
|
||||
logger.info(`[${requestId}] E2B status`, {
|
||||
enabled: e2bEnabled,
|
||||
hasApiKey: Boolean(process.env.E2B_API_KEY),
|
||||
language: lang,
|
||||
})
|
||||
let prologue = ''
|
||||
const epilogue = ''
|
||||
|
||||
// // Extract npm packages from code if needed
|
||||
// const importRegex =
|
||||
// /import\s+?(?:(?:(?:[\w*\s{},]*)\s+from\s+?)|)(?:(?:"([^"]*)")|(?:'([^']*)'))[^;]*/g
|
||||
// const requireRegex = /const\s+[\w\s{}]*\s*=\s*require\s*\(\s*['"]([^'"]+)['"]\s*\)/g
|
||||
if (lang === CodeLanguage.JavaScript) {
|
||||
// Track prologue lines for error adjustment
|
||||
let prologueLineCount = 0
|
||||
prologue += `const params = JSON.parse(${JSON.stringify(JSON.stringify(executionParams))});\n`
|
||||
prologueLineCount++
|
||||
prologue += `const environmentVariables = JSON.parse(${JSON.stringify(JSON.stringify(envVars))});\n`
|
||||
prologueLineCount++
|
||||
for (const [k, v] of Object.entries(contextVariables)) {
|
||||
prologue += `const ${k} = JSON.parse(${JSON.stringify(JSON.stringify(v))});\n`
|
||||
prologueLineCount++
|
||||
}
|
||||
const wrapped = [
|
||||
';(async () => {',
|
||||
' try {',
|
||||
' const __sim_result = await (async () => {',
|
||||
` ${resolvedCode.split('\n').join('\n ')}`,
|
||||
' })();',
|
||||
" console.log('__SIM_RESULT__=' + JSON.stringify(__sim_result));",
|
||||
' } catch (error) {',
|
||||
' console.log(String((error && (error.stack || error.message)) || error));',
|
||||
' throw error;',
|
||||
' }',
|
||||
'})();',
|
||||
].join('\n')
|
||||
const codeForE2B = prologue + wrapped + epilogue
|
||||
|
||||
// const packages: Record<string, string> = {}
|
||||
// const matches = [
|
||||
// ...resolvedCode.matchAll(importRegex),
|
||||
// ...resolvedCode.matchAll(requireRegex),
|
||||
// ]
|
||||
const execStart = Date.now()
|
||||
const {
|
||||
result: e2bResult,
|
||||
stdout: e2bStdout,
|
||||
sandboxId,
|
||||
error: e2bError,
|
||||
} = await executeInE2B({
|
||||
code: codeForE2B,
|
||||
language: CodeLanguage.JavaScript,
|
||||
timeoutMs: timeout,
|
||||
})
|
||||
const executionTime = Date.now() - execStart
|
||||
stdout += e2bStdout
|
||||
|
||||
// // Extract package names from import statements
|
||||
// for (const match of matches) {
|
||||
// const packageName = match[1] || match[2]
|
||||
// if (packageName && !packageName.startsWith('.') && !packageName.startsWith('/')) {
|
||||
// // Extract just the package name without version or subpath
|
||||
// const basePackageName = packageName.split('/')[0]
|
||||
// packages[basePackageName] = 'latest' // Use latest version
|
||||
// }
|
||||
// }
|
||||
logger.info(`[${requestId}] E2B JS sandbox`, {
|
||||
sandboxId,
|
||||
stdoutPreview: e2bStdout?.slice(0, 200),
|
||||
error: e2bError,
|
||||
})
|
||||
|
||||
// const freestyle = new FreestyleSandboxes({
|
||||
// apiKey: env.FREESTYLE_API_KEY,
|
||||
// })
|
||||
// If there was an execution error, format it properly
|
||||
if (e2bError) {
|
||||
const { formattedError, cleanedOutput } = formatE2BError(
|
||||
e2bError,
|
||||
e2bStdout,
|
||||
lang,
|
||||
resolvedCode,
|
||||
prologueLineCount
|
||||
)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: formattedError,
|
||||
output: { result: null, stdout: cleanedOutput, executionTime },
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
// // Wrap code in export default to match Freestyle's expectations
|
||||
// const wrappedCode = isCustomTool
|
||||
// ? `export default async () => {
|
||||
// // For custom tools, directly declare parameters as variables
|
||||
// ${Object.entries(executionParams)
|
||||
// .map(([key, value]) => `const ${key} = ${safeJSONStringify(value)};`)
|
||||
// .join('\n ')}
|
||||
// ${resolvedCode}
|
||||
// }`
|
||||
// : `export default async () => { ${resolvedCode} }`
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
output: { result: e2bResult ?? null, stdout, executionTime },
|
||||
})
|
||||
}
|
||||
// Track prologue lines for error adjustment
|
||||
let prologueLineCount = 0
|
||||
prologue += 'import json\n'
|
||||
prologueLineCount++
|
||||
prologue += `params = json.loads(${JSON.stringify(JSON.stringify(executionParams))})\n`
|
||||
prologueLineCount++
|
||||
prologue += `environmentVariables = json.loads(${JSON.stringify(JSON.stringify(envVars))})\n`
|
||||
prologueLineCount++
|
||||
for (const [k, v] of Object.entries(contextVariables)) {
|
||||
prologue += `${k} = json.loads(${JSON.stringify(JSON.stringify(v))})\n`
|
||||
prologueLineCount++
|
||||
}
|
||||
const wrapped = [
|
||||
'def __sim_main__():',
|
||||
...resolvedCode.split('\n').map((l) => ` ${l}`),
|
||||
'__sim_result__ = __sim_main__()',
|
||||
"print('__SIM_RESULT__=' + json.dumps(__sim_result__))",
|
||||
].join('\n')
|
||||
const codeForE2B = prologue + wrapped + epilogue
|
||||
|
||||
// // Execute the code with Freestyle
|
||||
// const res = await freestyle.executeScript(wrappedCode, {
|
||||
// nodeModules: packages,
|
||||
// timeout: null,
|
||||
// envVars: envVars,
|
||||
// })
|
||||
const execStart = Date.now()
|
||||
const {
|
||||
result: e2bResult,
|
||||
stdout: e2bStdout,
|
||||
sandboxId,
|
||||
error: e2bError,
|
||||
} = await executeInE2B({
|
||||
code: codeForE2B,
|
||||
language: CodeLanguage.Python,
|
||||
timeoutMs: timeout,
|
||||
})
|
||||
const executionTime = Date.now() - execStart
|
||||
stdout += e2bStdout
|
||||
|
||||
// // Check for direct API error response
|
||||
// // Type assertion since the library types don't include error response
|
||||
// const response = res as { _type?: string; error?: string }
|
||||
// if (response._type === 'error' && response.error) {
|
||||
// logger.error(`[${requestId}] Freestyle returned error response`, {
|
||||
// error: response.error,
|
||||
// })
|
||||
// throw response.error
|
||||
// }
|
||||
logger.info(`[${requestId}] E2B Py sandbox`, {
|
||||
sandboxId,
|
||||
stdoutPreview: e2bStdout?.slice(0, 200),
|
||||
error: e2bError,
|
||||
})
|
||||
|
||||
// // Capture stdout/stderr from Freestyle logs
|
||||
// stdout =
|
||||
// res.logs
|
||||
// ?.map((log) => (log.type === 'error' ? 'ERROR: ' : '') + log.message)
|
||||
// .join('\n') || ''
|
||||
// If there was an execution error, format it properly
|
||||
if (e2bError) {
|
||||
const { formattedError, cleanedOutput } = formatE2BError(
|
||||
e2bError,
|
||||
e2bStdout,
|
||||
lang,
|
||||
resolvedCode,
|
||||
prologueLineCount
|
||||
)
|
||||
return NextResponse.json(
|
||||
{
|
||||
success: false,
|
||||
error: formattedError,
|
||||
output: { result: null, stdout: cleanedOutput, executionTime },
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
// // Check for errors reported within Freestyle logs
|
||||
// const freestyleErrors = res.logs?.filter((log) => log.type === 'error') || []
|
||||
// if (freestyleErrors.length > 0) {
|
||||
// const errorMessage = freestyleErrors.map((log) => log.message).join('\n')
|
||||
// logger.error(`[${requestId}] Freestyle execution completed with script errors`, {
|
||||
// errorMessage,
|
||||
// stdout,
|
||||
// })
|
||||
// // Create a proper Error object to be caught by the outer handler
|
||||
// const scriptError = new Error(errorMessage)
|
||||
// scriptError.name = 'FreestyleScriptError'
|
||||
// throw scriptError
|
||||
// }
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
output: { result: e2bResult ?? null, stdout, executionTime },
|
||||
})
|
||||
}
|
||||
|
||||
// // If no errors, execution was successful
|
||||
// result = res.result
|
||||
// logger.info(`[${requestId}] Freestyle execution successful`, {
|
||||
// result,
|
||||
// stdout,
|
||||
// })
|
||||
// } catch (error: any) {
|
||||
// // Check if the error came from our explicit throw above due to script errors
|
||||
// if (error.name === 'FreestyleScriptError') {
|
||||
// throw error // Re-throw to be caught by the outer handler
|
||||
// }
|
||||
|
||||
// // Otherwise, it's likely a Freestyle API call error (network, auth, config, etc.) -> Fallback to VM
|
||||
// logger.error(`[${requestId}] Freestyle API call failed, falling back to VM:`, {
|
||||
// error: error.message,
|
||||
// stack: error.stack,
|
||||
// })
|
||||
// executionMethod = 'vm_fallback'
|
||||
|
||||
// // Continue to VM execution
|
||||
// const context = createContext({
|
||||
// params: executionParams,
|
||||
// environmentVariables: envVars,
|
||||
// console: {
|
||||
// log: (...args: any[]) => {
|
||||
// const logMessage = `${args
|
||||
// .map((arg) => (typeof arg === 'object' ? JSON.stringify(arg) : String(arg)))
|
||||
// .join(' ')}\n`
|
||||
// stdout += logMessage
|
||||
// },
|
||||
// error: (...args: any[]) => {
|
||||
// const errorMessage = `${args
|
||||
// .map((arg) => (typeof arg === 'object' ? JSON.stringify(arg) : String(arg)))
|
||||
// .join(' ')}\n`
|
||||
// logger.error(`[${requestId}] Code Console Error: ${errorMessage}`)
|
||||
// stdout += `ERROR: ${errorMessage}`
|
||||
// },
|
||||
// },
|
||||
// })
|
||||
|
||||
// const script = new Script(`
|
||||
// (async () => {
|
||||
// try {
|
||||
// ${
|
||||
// isCustomTool
|
||||
// ? `// For custom tools, make parameters directly accessible
|
||||
// ${Object.keys(executionParams)
|
||||
// .map((key) => `const ${key} = params.${key};`)
|
||||
// .join('\n ')}`
|
||||
// : ''
|
||||
// }
|
||||
// ${resolvedCode}
|
||||
// } catch (error) {
|
||||
// console.error(error);
|
||||
// throw error;
|
||||
// }
|
||||
// })()
|
||||
// `)
|
||||
|
||||
// result = await script.runInContext(context, {
|
||||
// timeout,
|
||||
// displayErrors: true,
|
||||
// })
|
||||
// }
|
||||
// } else {
|
||||
logger.info(`[${requestId}] Using VM for code execution`, {
|
||||
resolvedCode,
|
||||
hasEnvVars: Object.keys(envVars).length > 0,
|
||||
})
|
||||
|
||||
// Create a secure context with console logging
|
||||
const executionMethod = 'vm'
|
||||
const context = createContext({
|
||||
params: executionParams,
|
||||
environmentVariables: envVars,
|
||||
...contextVariables, // Add resolved variables directly to context
|
||||
fetch: globalThis.fetch || require('node-fetch').default,
|
||||
...contextVariables,
|
||||
fetch: (globalThis as any).fetch || require('node-fetch').default,
|
||||
console: {
|
||||
log: (...args: any[]) => {
|
||||
const logMessage = `${args
|
||||
@@ -539,23 +755,17 @@ export async function POST(req: NextRequest) {
|
||||
},
|
||||
})
|
||||
|
||||
// Calculate line offset for user code to provide accurate error reporting
|
||||
const wrapperLines = ['(async () => {', ' try {']
|
||||
|
||||
// Add custom tool parameter declarations if needed
|
||||
if (isCustomTool) {
|
||||
wrapperLines.push(' // For custom tools, make parameters directly accessible')
|
||||
Object.keys(executionParams).forEach((key) => {
|
||||
wrapperLines.push(` const ${key} = params.${key};`)
|
||||
})
|
||||
}
|
||||
|
||||
userCodeStartLine = wrapperLines.length + 1 // +1 because user code starts on next line
|
||||
|
||||
// Build the complete script with proper formatting for line numbers
|
||||
userCodeStartLine = wrapperLines.length + 1
|
||||
const fullScript = [
|
||||
...wrapperLines,
|
||||
` ${resolvedCode.split('\n').join('\n ')}`, // Indent user code
|
||||
` ${resolvedCode.split('\n').join('\n ')}`,
|
||||
' } catch (error) {',
|
||||
' console.error(error);',
|
||||
' throw error;',
|
||||
@@ -564,33 +774,26 @@ export async function POST(req: NextRequest) {
|
||||
].join('\n')
|
||||
|
||||
const script = new Script(fullScript, {
|
||||
filename: 'user-function.js', // This filename will appear in stack traces
|
||||
lineOffset: 0, // Start line numbering from 0
|
||||
columnOffset: 0, // Start column numbering from 0
|
||||
filename: 'user-function.js',
|
||||
lineOffset: 0,
|
||||
columnOffset: 0,
|
||||
})
|
||||
|
||||
const result = await script.runInContext(context, {
|
||||
timeout,
|
||||
displayErrors: true,
|
||||
breakOnSigint: true, // Allow breaking on SIGINT for better debugging
|
||||
breakOnSigint: true,
|
||||
})
|
||||
// }
|
||||
|
||||
const executionTime = Date.now() - startTime
|
||||
logger.info(`[${requestId}] Function executed successfully using ${executionMethod}`, {
|
||||
executionTime,
|
||||
})
|
||||
|
||||
const response = {
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
output: {
|
||||
result,
|
||||
stdout,
|
||||
executionTime,
|
||||
},
|
||||
}
|
||||
|
||||
return NextResponse.json(response)
|
||||
output: { result, stdout, executionTime },
|
||||
})
|
||||
} catch (error: any) {
|
||||
const executionTime = Date.now() - startTime
|
||||
logger.error(`[${requestId}] Function execution failed`, {
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { Resend } from 'resend'
|
||||
import { z } from 'zod'
|
||||
import { renderHelpConfirmationEmail } from '@/components/emails'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { sendEmail } from '@/lib/email/mailer'
|
||||
import { getFromEmailAddress } from '@/lib/email/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getEmailDomain } from '@/lib/urls/utils'
|
||||
|
||||
const resend = env.RESEND_API_KEY ? new Resend(env.RESEND_API_KEY) : null
|
||||
const logger = createLogger('HelpAPI')
|
||||
|
||||
const helpFormSchema = z.object({
|
||||
email: z.string().email('Invalid email address'),
|
||||
subject: z.string().min(1, 'Subject is required'),
|
||||
message: z.string().min(1, 'Message is required'),
|
||||
type: z.enum(['bug', 'feedback', 'feature_request', 'other']),
|
||||
@@ -19,23 +20,19 @@ export async function POST(req: NextRequest) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
// Check if Resend API key is configured
|
||||
if (!resend) {
|
||||
logger.error(`[${requestId}] RESEND_API_KEY not configured`)
|
||||
return NextResponse.json(
|
||||
{
|
||||
error:
|
||||
'Email service not configured. Please set RESEND_API_KEY in environment variables.',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
// Get user session
|
||||
const session = await getSession()
|
||||
if (!session?.user?.email) {
|
||||
logger.warn(`[${requestId}] Unauthorized help request attempt`)
|
||||
return NextResponse.json({ error: 'Authentication required' }, { status: 401 })
|
||||
}
|
||||
|
||||
const email = session.user.email
|
||||
|
||||
// Handle multipart form data
|
||||
const formData = await req.formData()
|
||||
|
||||
// Extract form fields
|
||||
const email = formData.get('email') as string
|
||||
const subject = formData.get('subject') as string
|
||||
const message = formData.get('message') as string
|
||||
const type = formData.get('type') as string
|
||||
@@ -46,19 +43,18 @@ export async function POST(req: NextRequest) {
|
||||
})
|
||||
|
||||
// Validate the form data
|
||||
const result = helpFormSchema.safeParse({
|
||||
email,
|
||||
const validationResult = helpFormSchema.safeParse({
|
||||
subject,
|
||||
message,
|
||||
type,
|
||||
})
|
||||
|
||||
if (!result.success) {
|
||||
if (!validationResult.success) {
|
||||
logger.warn(`[${requestId}] Invalid help request data`, {
|
||||
errors: result.error.format(),
|
||||
errors: validationResult.error.format(),
|
||||
})
|
||||
return NextResponse.json(
|
||||
{ error: 'Invalid request data', details: result.error.format() },
|
||||
{ error: 'Invalid request data', details: validationResult.error.format() },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
@@ -96,63 +92,60 @@ ${message}
|
||||
emailText += `\n\n${images.length} image(s) attached.`
|
||||
}
|
||||
|
||||
// Send email using Resend
|
||||
const { data, error } = await resend.emails.send({
|
||||
from: `Sim <noreply@${getEmailDomain()}>`,
|
||||
to: [`help@${getEmailDomain()}`],
|
||||
const emailResult = await sendEmail({
|
||||
to: [`help@${env.EMAIL_DOMAIN || getEmailDomain()}`],
|
||||
subject: `[${type.toUpperCase()}] ${subject}`,
|
||||
replyTo: email,
|
||||
text: emailText,
|
||||
from: getFromEmailAddress(),
|
||||
replyTo: email,
|
||||
emailType: 'transactional',
|
||||
attachments: images.map((image) => ({
|
||||
filename: image.filename,
|
||||
content: image.content.toString('base64'),
|
||||
contentType: image.contentType,
|
||||
disposition: 'attachment', // Explicitly set as attachment
|
||||
disposition: 'attachment',
|
||||
})),
|
||||
})
|
||||
|
||||
if (error) {
|
||||
logger.error(`[${requestId}] Error sending help request email`, error)
|
||||
if (!emailResult.success) {
|
||||
logger.error(`[${requestId}] Error sending help request email`, emailResult.message)
|
||||
return NextResponse.json({ error: 'Failed to send email' }, { status: 500 })
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Help request email sent successfully`)
|
||||
|
||||
// Send confirmation email to the user
|
||||
await resend.emails
|
||||
.send({
|
||||
from: `Sim <noreply@${getEmailDomain()}>`,
|
||||
try {
|
||||
const confirmationHtml = await renderHelpConfirmationEmail(
|
||||
email,
|
||||
type as 'bug' | 'feedback' | 'feature_request' | 'other',
|
||||
images.length
|
||||
)
|
||||
|
||||
await sendEmail({
|
||||
to: [email],
|
||||
subject: `Your ${type} request has been received: ${subject}`,
|
||||
text: `
|
||||
Hello,
|
||||
|
||||
Thank you for your ${type} submission. We've received your request and will get back to you as soon as possible.
|
||||
|
||||
Your message:
|
||||
${message}
|
||||
|
||||
${images.length > 0 ? `You attached ${images.length} image(s).` : ''}
|
||||
|
||||
Best regards,
|
||||
The Sim Team
|
||||
`,
|
||||
replyTo: `help@${getEmailDomain()}`,
|
||||
})
|
||||
.catch((err) => {
|
||||
logger.warn(`[${requestId}] Failed to send confirmation email`, err)
|
||||
html: confirmationHtml,
|
||||
from: getFromEmailAddress(),
|
||||
replyTo: `help@${env.EMAIL_DOMAIN || getEmailDomain()}`,
|
||||
emailType: 'transactional',
|
||||
})
|
||||
} catch (err) {
|
||||
logger.warn(`[${requestId}] Failed to send confirmation email`, err)
|
||||
}
|
||||
|
||||
return NextResponse.json(
|
||||
{ success: true, message: 'Help request submitted successfully' },
|
||||
{ status: 200 }
|
||||
)
|
||||
} catch (error) {
|
||||
// Check if error is related to missing API key
|
||||
if (error instanceof Error && error.message.includes('API key')) {
|
||||
logger.error(`[${requestId}] API key configuration error`, error)
|
||||
if (error instanceof Error && error.message.includes('not configured')) {
|
||||
logger.error(`[${requestId}] Email service configuration error`, error)
|
||||
return NextResponse.json(
|
||||
{ error: 'Email service configuration error. Please check your RESEND_API_KEY.' },
|
||||
{
|
||||
error:
|
||||
'Email service configuration error. Please check your email service configuration.',
|
||||
},
|
||||
{ status: 500 }
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import { runs } from '@trigger.dev/sdk/v3'
|
||||
import { runs } from '@trigger.dev/sdk'
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import { createErrorResponse } from '@/app/api/workflows/utils'
|
||||
import { db } from '@/db'
|
||||
import { apiKey as apiKeyTable } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('TaskStatusAPI')
|
||||
|
||||
@@ -1,15 +1,10 @@
|
||||
import { createHash, randomUUID } from 'crypto'
|
||||
import { eq, sql } from 'drizzle-orm'
|
||||
import { randomUUID } from 'crypto'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { deleteChunk, updateChunk } from '@/lib/knowledge/chunks/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import { checkChunkAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('ChunkByIdAPI')
|
||||
|
||||
@@ -105,33 +100,7 @@ export async function PUT(
|
||||
try {
|
||||
const validatedData = UpdateChunkSchema.parse(body)
|
||||
|
||||
const updateData: Partial<{
|
||||
content: string
|
||||
contentLength: number
|
||||
tokenCount: number
|
||||
chunkHash: string
|
||||
enabled: boolean
|
||||
updatedAt: Date
|
||||
}> = {}
|
||||
|
||||
if (validatedData.content) {
|
||||
updateData.content = validatedData.content
|
||||
updateData.contentLength = validatedData.content.length
|
||||
// Update token count estimation (rough approximation: 4 chars per token)
|
||||
updateData.tokenCount = Math.ceil(validatedData.content.length / 4)
|
||||
updateData.chunkHash = createHash('sha256').update(validatedData.content).digest('hex')
|
||||
}
|
||||
|
||||
if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
|
||||
|
||||
await db.update(embedding).set(updateData).where(eq(embedding.id, chunkId))
|
||||
|
||||
// Fetch the updated chunk
|
||||
const updatedChunk = await db
|
||||
.select()
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
const updatedChunk = await updateChunk(chunkId, validatedData, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Chunk updated: ${chunkId} in document ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
@@ -139,7 +108,7 @@ export async function PUT(
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedChunk[0],
|
||||
data: updatedChunk,
|
||||
})
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
@@ -193,37 +162,7 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Use transaction to atomically delete chunk and update document statistics
|
||||
await db.transaction(async (tx) => {
|
||||
// Get chunk data before deletion for statistics update
|
||||
const chunkToDelete = await tx
|
||||
.select({
|
||||
tokenCount: embedding.tokenCount,
|
||||
contentLength: embedding.contentLength,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(eq(embedding.id, chunkId))
|
||||
.limit(1)
|
||||
|
||||
if (chunkToDelete.length === 0) {
|
||||
throw new Error('Chunk not found')
|
||||
}
|
||||
|
||||
const chunk = chunkToDelete[0]
|
||||
|
||||
// Delete the chunk
|
||||
await tx.delete(embedding).where(eq(embedding.id, chunkId))
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} - 1`,
|
||||
tokenCount: sql`${document.tokenCount} - ${chunk.tokenCount}`,
|
||||
characterCount: sql`${document.characterCount} - ${chunk.contentLength}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
await deleteChunk(chunkId, documentId, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Chunk deleted: ${chunkId} from document ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
|
||||
@@ -1,378 +0,0 @@
|
||||
/**
|
||||
* Tests for knowledge document chunks API route
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import {
|
||||
createMockRequest,
|
||||
mockAuth,
|
||||
mockConsoleLogger,
|
||||
mockDrizzleOrm,
|
||||
mockKnowledgeSchemas,
|
||||
} from '@/app/api/__test-utils__/utils'
|
||||
|
||||
mockKnowledgeSchemas()
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
vi.mock('@/lib/tokenization/estimators', () => ({
|
||||
estimateTokenCount: vi.fn().mockReturnValue({ count: 452 }),
|
||||
}))
|
||||
|
||||
vi.mock('@/providers/utils', () => ({
|
||||
calculateCost: vi.fn().mockReturnValue({
|
||||
input: 0.00000904,
|
||||
output: 0,
|
||||
total: 0.00000904,
|
||||
pricing: {
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
},
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: vi.fn(),
|
||||
checkKnowledgeBaseWriteAccess: vi.fn(),
|
||||
checkDocumentAccess: vi.fn(),
|
||||
checkDocumentWriteAccess: vi.fn(),
|
||||
checkChunkAccess: vi.fn(),
|
||||
generateEmbeddings: vi.fn().mockResolvedValue([[0.1, 0.2, 0.3, 0.4, 0.5]]),
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('Knowledge Document Chunks API Route', () => {
|
||||
const mockAuth$ = mockAuth()
|
||||
|
||||
const mockDbChain = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockReturnThis(),
|
||||
offset: vi.fn().mockReturnThis(),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
returning: vi.fn().mockResolvedValue([]),
|
||||
delete: vi.fn().mockReturnThis(),
|
||||
transaction: vi.fn(),
|
||||
}
|
||||
|
||||
const mockGetUserId = vi.fn()
|
||||
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
vi.doMock('@/db', () => ({
|
||||
db: mockDbChain,
|
||||
}))
|
||||
|
||||
vi.doMock('@/app/api/auth/oauth/utils', () => ({
|
||||
getUserId: mockGetUserId,
|
||||
}))
|
||||
|
||||
Object.values(mockDbChain).forEach((fn) => {
|
||||
if (typeof fn === 'function' && fn !== mockDbChain.values && fn !== mockDbChain.returning) {
|
||||
fn.mockClear().mockReturnThis()
|
||||
}
|
||||
})
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-chunk-uuid-1234'),
|
||||
createHash: vi.fn().mockReturnValue({
|
||||
update: vi.fn().mockReturnThis(),
|
||||
digest: vi.fn().mockReturnValue('mock-hash-123'),
|
||||
}),
|
||||
})
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('POST /api/knowledge/[id]/documents/[documentId]/chunks', () => {
|
||||
const validChunkData = {
|
||||
content: 'This is test chunk content for uploading to the knowledge base document.',
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
const mockDocumentAccess = {
|
||||
hasAccess: true,
|
||||
notFound: false,
|
||||
reason: '',
|
||||
document: {
|
||||
id: 'doc-123',
|
||||
processingStatus: 'completed',
|
||||
tag1: 'tag1-value',
|
||||
tag2: 'tag2-value',
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
},
|
||||
}
|
||||
|
||||
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
|
||||
|
||||
it('should create chunk successfully with cost tracking', async () => {
|
||||
const { checkDocumentWriteAccess, generateEmbeddings } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { estimateTokenCount } = await import('@/lib/tokenization/estimators')
|
||||
const { calculateCost } = await import('@/providers/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
// Mock generateEmbeddings
|
||||
vi.mocked(generateEmbeddings).mockResolvedValue([[0.1, 0.2, 0.3]])
|
||||
|
||||
// Mock transaction
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([{ chunkIndex: 0 }]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
|
||||
// Verify cost tracking
|
||||
expect(data.data.cost).toBeDefined()
|
||||
expect(data.data.cost.input).toBe(0.00000904)
|
||||
expect(data.data.cost.output).toBe(0)
|
||||
expect(data.data.cost.total).toBe(0.00000904)
|
||||
expect(data.data.cost.tokens).toEqual({
|
||||
prompt: 452,
|
||||
completion: 0,
|
||||
total: 452,
|
||||
})
|
||||
expect(data.data.cost.model).toBe('text-embedding-3-small')
|
||||
expect(data.data.cost.pricing).toEqual({
|
||||
input: 0.02,
|
||||
output: 0,
|
||||
updatedAt: '2025-07-10',
|
||||
})
|
||||
|
||||
// Verify function calls
|
||||
expect(estimateTokenCount).toHaveBeenCalledWith(validChunkData.content, 'openai')
|
||||
expect(calculateCost).toHaveBeenCalledWith('text-embedding-3-small', 452, 0, false)
|
||||
})
|
||||
|
||||
it('should handle workflow-based authentication', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
const workflowData = {
|
||||
...validChunkData,
|
||||
workflowId: 'workflow-123',
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', workflowData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(mockGetUserId).toHaveBeenCalledWith(expect.any(String), 'workflow-123')
|
||||
})
|
||||
|
||||
it.concurrent('should return unauthorized for unauthenticated request', async () => {
|
||||
mockGetUserId.mockResolvedValue(null)
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should return not found for workflow that does not exist', async () => {
|
||||
const workflowData = {
|
||||
...validChunkData,
|
||||
workflowId: 'nonexistent-workflow',
|
||||
}
|
||||
|
||||
mockGetUserId.mockResolvedValue(null)
|
||||
|
||||
const req = createMockRequest('POST', workflowData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Workflow not found')
|
||||
})
|
||||
|
||||
it.concurrent('should return not found for document access denied', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
reason: 'Document not found',
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Document not found')
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthorized document access', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
reason: 'Unauthorized access',
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(401)
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should reject chunks for failed documents', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
document: {
|
||||
...mockDocumentAccess.document!,
|
||||
processingStatus: 'failed',
|
||||
},
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Cannot add chunks to failed document')
|
||||
})
|
||||
|
||||
it.concurrent('should validate chunk data', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const invalidData = {
|
||||
content: '', // Empty content
|
||||
enabled: true,
|
||||
}
|
||||
|
||||
const req = createMockRequest('POST', invalidData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(400)
|
||||
expect(data.error).toBe('Invalid request data')
|
||||
expect(data.details).toBeDefined()
|
||||
})
|
||||
|
||||
it('should inherit tags from parent document', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetUserId.mockResolvedValue('user-123')
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
...mockDocumentAccess,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
} as any)
|
||||
|
||||
const mockTx = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
where: vi.fn().mockReturnThis(),
|
||||
orderBy: vi.fn().mockReturnThis(),
|
||||
limit: vi.fn().mockResolvedValue([]),
|
||||
insert: vi.fn().mockReturnThis(),
|
||||
values: vi.fn().mockImplementation((data) => {
|
||||
// Verify that tags are inherited from document
|
||||
expect(data.tag1).toBe('tag1-value')
|
||||
expect(data.tag2).toBe('tag2-value')
|
||||
expect(data.tag3).toBe(null)
|
||||
return Promise.resolve(undefined)
|
||||
}),
|
||||
update: vi.fn().mockReturnThis(),
|
||||
set: vi.fn().mockReturnThis(),
|
||||
}
|
||||
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
return await callback(mockTx)
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validChunkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/[documentId]/chunks/route')
|
||||
await POST(req, { params: mockParams })
|
||||
|
||||
expect(mockTx.values).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// REMOVED: "should handle cost calculation with different content lengths" test - it was failing
|
||||
})
|
||||
})
|
||||
@@ -1,21 +1,11 @@
|
||||
import crypto from 'crypto'
|
||||
import { and, asc, eq, ilike, inArray, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { batchChunkOperation, createChunk, queryChunks } from '@/lib/knowledge/chunks/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import {
|
||||
checkDocumentAccess,
|
||||
checkDocumentWriteAccess,
|
||||
generateEmbeddings,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
|
||||
const logger = createLogger('DocumentChunksAPI')
|
||||
@@ -69,7 +59,6 @@ export async function GET(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if document processing is completed
|
||||
const doc = accessCheck.document
|
||||
if (!doc) {
|
||||
logger.warn(
|
||||
@@ -92,7 +81,6 @@ export async function GET(
|
||||
)
|
||||
}
|
||||
|
||||
// Parse query parameters
|
||||
const { searchParams } = new URL(req.url)
|
||||
const queryParams = GetChunksQuerySchema.parse({
|
||||
search: searchParams.get('search') || undefined,
|
||||
@@ -101,67 +89,12 @@ export async function GET(
|
||||
offset: searchParams.get('offset') || undefined,
|
||||
})
|
||||
|
||||
// Build query conditions
|
||||
const conditions = [eq(embedding.documentId, documentId)]
|
||||
|
||||
// Add enabled filter
|
||||
if (queryParams.enabled === 'true') {
|
||||
conditions.push(eq(embedding.enabled, true))
|
||||
} else if (queryParams.enabled === 'false') {
|
||||
conditions.push(eq(embedding.enabled, false))
|
||||
}
|
||||
|
||||
// Add search filter
|
||||
if (queryParams.search) {
|
||||
conditions.push(ilike(embedding.content, `%${queryParams.search}%`))
|
||||
}
|
||||
|
||||
// Fetch chunks
|
||||
const chunks = await db
|
||||
.select({
|
||||
id: embedding.id,
|
||||
chunkIndex: embedding.chunkIndex,
|
||||
content: embedding.content,
|
||||
contentLength: embedding.contentLength,
|
||||
tokenCount: embedding.tokenCount,
|
||||
enabled: embedding.enabled,
|
||||
startOffset: embedding.startOffset,
|
||||
endOffset: embedding.endOffset,
|
||||
tag1: embedding.tag1,
|
||||
tag2: embedding.tag2,
|
||||
tag3: embedding.tag3,
|
||||
tag4: embedding.tag4,
|
||||
tag5: embedding.tag5,
|
||||
tag6: embedding.tag6,
|
||||
tag7: embedding.tag7,
|
||||
createdAt: embedding.createdAt,
|
||||
updatedAt: embedding.updatedAt,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
.orderBy(asc(embedding.chunkIndex))
|
||||
.limit(queryParams.limit)
|
||||
.offset(queryParams.offset)
|
||||
|
||||
// Get total count for pagination
|
||||
const totalCount = await db
|
||||
.select({ count: sql`count(*)` })
|
||||
.from(embedding)
|
||||
.where(and(...conditions))
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Retrieved ${chunks.length} chunks for document ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
const result = await queryChunks(documentId, queryParams, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: chunks,
|
||||
pagination: {
|
||||
total: Number(totalCount[0]?.count || 0),
|
||||
limit: queryParams.limit,
|
||||
offset: queryParams.offset,
|
||||
hasMore: chunks.length === queryParams.limit,
|
||||
},
|
||||
data: result.chunks,
|
||||
pagination: result.pagination,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error fetching chunks`, error)
|
||||
@@ -222,76 +155,27 @@ export async function POST(
|
||||
try {
|
||||
const validatedData = CreateChunkSchema.parse(searchParams)
|
||||
|
||||
// Generate embedding for the content first (outside transaction for performance)
|
||||
logger.info(`[${requestId}] Generating embedding for manual chunk`)
|
||||
const embeddings = await generateEmbeddings([validatedData.content])
|
||||
const docTags = {
|
||||
tag1: doc.tag1 ?? null,
|
||||
tag2: doc.tag2 ?? null,
|
||||
tag3: doc.tag3 ?? null,
|
||||
tag4: doc.tag4 ?? null,
|
||||
tag5: doc.tag5 ?? null,
|
||||
tag6: doc.tag6 ?? null,
|
||||
tag7: doc.tag7 ?? null,
|
||||
}
|
||||
|
||||
// Calculate accurate token count for both database storage and cost calculation
|
||||
const tokenCount = estimateTokenCount(validatedData.content, 'openai')
|
||||
const newChunk = await createChunk(
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
docTags,
|
||||
validatedData,
|
||||
requestId
|
||||
)
|
||||
|
||||
const chunkId = crypto.randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Use transaction to atomically get next index and insert chunk
|
||||
const newChunk = await db.transaction(async (tx) => {
|
||||
// Get the next chunk index atomically within the transaction
|
||||
const lastChunk = await tx
|
||||
.select({ chunkIndex: embedding.chunkIndex })
|
||||
.from(embedding)
|
||||
.where(eq(embedding.documentId, documentId))
|
||||
.orderBy(sql`${embedding.chunkIndex} DESC`)
|
||||
.limit(1)
|
||||
|
||||
const nextChunkIndex = lastChunk.length > 0 ? lastChunk[0].chunkIndex + 1 : 0
|
||||
|
||||
const chunkData = {
|
||||
id: chunkId,
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
chunkIndex: nextChunkIndex,
|
||||
chunkHash: crypto.createHash('sha256').update(validatedData.content).digest('hex'),
|
||||
content: validatedData.content,
|
||||
contentLength: validatedData.content.length,
|
||||
tokenCount: tokenCount.count, // Use accurate token count
|
||||
embedding: embeddings[0],
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
startOffset: 0, // Manual chunks don't have document offsets
|
||||
endOffset: validatedData.content.length,
|
||||
// Inherit tags from parent document
|
||||
tag1: doc.tag1,
|
||||
tag2: doc.tag2,
|
||||
tag3: doc.tag3,
|
||||
tag4: doc.tag4,
|
||||
tag5: doc.tag5,
|
||||
tag6: doc.tag6,
|
||||
tag7: doc.tag7,
|
||||
enabled: validatedData.enabled,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
// Insert the new chunk
|
||||
await tx.insert(embedding).values(chunkData)
|
||||
|
||||
// Update document statistics
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} + 1`,
|
||||
tokenCount: sql`${document.tokenCount} + ${chunkData.tokenCount}`,
|
||||
characterCount: sql`${document.characterCount} + ${chunkData.contentLength}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
return chunkData
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Manual chunk created: ${chunkId} in document ${documentId}`)
|
||||
|
||||
// Calculate cost for the embedding (with fallback if calculation fails)
|
||||
let cost = null
|
||||
try {
|
||||
cost = calculateCost('text-embedding-3-small', tokenCount.count, 0, false)
|
||||
cost = calculateCost('text-embedding-3-small', newChunk.tokenCount, 0, false)
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to calculate cost for chunk upload`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
@@ -303,6 +187,8 @@ export async function POST(
|
||||
success: true,
|
||||
data: {
|
||||
...newChunk,
|
||||
documentId,
|
||||
documentName: doc.filename,
|
||||
...(cost
|
||||
? {
|
||||
cost: {
|
||||
@@ -310,9 +196,9 @@ export async function POST(
|
||||
output: cost.output,
|
||||
total: cost.total,
|
||||
tokens: {
|
||||
prompt: tokenCount.count,
|
||||
prompt: newChunk.tokenCount,
|
||||
completion: 0,
|
||||
total: tokenCount.count,
|
||||
total: newChunk.tokenCount,
|
||||
},
|
||||
model: 'text-embedding-3-small',
|
||||
pricing: cost.pricing,
|
||||
@@ -374,92 +260,16 @@ export async function PATCH(
|
||||
const validatedData = BatchOperationSchema.parse(body)
|
||||
const { operation, chunkIds } = validatedData
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting batch ${operation} operation on ${chunkIds.length} chunks for document ${documentId}`
|
||||
)
|
||||
|
||||
const results = []
|
||||
let successCount = 0
|
||||
const errorCount = 0
|
||||
|
||||
if (operation === 'delete') {
|
||||
// Handle batch delete with transaction for consistency
|
||||
await db.transaction(async (tx) => {
|
||||
// Get chunks to delete for statistics update
|
||||
const chunksToDelete = await tx
|
||||
.select({
|
||||
id: embedding.id,
|
||||
tokenCount: embedding.tokenCount,
|
||||
contentLength: embedding.contentLength,
|
||||
})
|
||||
.from(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
if (chunksToDelete.length === 0) {
|
||||
throw new Error('No valid chunks found to delete')
|
||||
}
|
||||
|
||||
// Delete chunks
|
||||
await tx
|
||||
.delete(embedding)
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
|
||||
// Update document statistics
|
||||
const totalTokens = chunksToDelete.reduce((sum, chunk) => sum + chunk.tokenCount, 0)
|
||||
const totalCharacters = chunksToDelete.reduce(
|
||||
(sum, chunk) => sum + chunk.contentLength,
|
||||
0
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: sql`${document.chunkCount} - ${chunksToDelete.length}`,
|
||||
tokenCount: sql`${document.tokenCount} - ${totalTokens}`,
|
||||
characterCount: sql`${document.characterCount} - ${totalCharacters}`,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
successCount = chunksToDelete.length
|
||||
results.push({
|
||||
operation: 'delete',
|
||||
deletedCount: chunksToDelete.length,
|
||||
chunkIds: chunksToDelete.map((c) => c.id),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
// Handle batch enable/disable
|
||||
const enabled = operation === 'enable'
|
||||
|
||||
// Update chunks in a single query
|
||||
const updateResult = await db
|
||||
.update(embedding)
|
||||
.set({
|
||||
enabled,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(and(eq(embedding.documentId, documentId), inArray(embedding.id, chunkIds)))
|
||||
.returning({ id: embedding.id })
|
||||
|
||||
successCount = updateResult.length
|
||||
results.push({
|
||||
operation,
|
||||
updatedCount: updateResult.length,
|
||||
chunkIds: updateResult.map((r) => r.id),
|
||||
})
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Batch ${operation} operation completed: ${successCount} successful, ${errorCount} errors`
|
||||
)
|
||||
const result = await batchChunkOperation(documentId, operation, chunkIds, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
operation,
|
||||
successCount,
|
||||
errorCount,
|
||||
results,
|
||||
successCount: result.processed,
|
||||
errorCount: result.errors.length,
|
||||
processed: result.processed,
|
||||
errors: result.errors,
|
||||
},
|
||||
})
|
||||
} catch (validationError) {
|
||||
|
||||
@@ -24,7 +24,14 @@ vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
// Setup common mocks
|
||||
vi.mock('@/lib/knowledge/documents/service', () => ({
|
||||
updateDocument: vi.fn(),
|
||||
deleteDocument: vi.fn(),
|
||||
markDocumentAsFailedTimeout: vi.fn(),
|
||||
retryDocumentProcessing: vi.fn(),
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
@@ -42,8 +49,6 @@ describe('Document By ID API Route', () => {
|
||||
transaction: vi.fn(),
|
||||
}
|
||||
|
||||
// Mock functions will be imported dynamically in tests
|
||||
|
||||
const mockDocument = {
|
||||
id: 'doc-123',
|
||||
knowledgeBaseId: 'kb-123',
|
||||
@@ -73,7 +78,6 @@ describe('Document By ID API Route', () => {
|
||||
}
|
||||
}
|
||||
})
|
||||
// Mock functions are cleared automatically by vitest
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -83,8 +87,6 @@ describe('Document By ID API Route', () => {
|
||||
db: mockDbChain,
|
||||
}))
|
||||
|
||||
// Utils are mocked at the top level
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
})
|
||||
@@ -195,6 +197,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should update document successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { updateDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -203,31 +206,12 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Create a sequence of mocks for the database operations
|
||||
const updateChain = {
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined), // Update operation completes
|
||||
}),
|
||||
const updatedDocument = {
|
||||
...mockDocument,
|
||||
...validUpdateData,
|
||||
deletedAt: null,
|
||||
}
|
||||
|
||||
const selectChain = {
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi.fn().mockResolvedValue([{ ...mockDocument, ...validUpdateData }]),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
update: vi.fn().mockReturnValue(updateChain),
|
||||
}
|
||||
await callback(mockTx)
|
||||
})
|
||||
|
||||
// Mock db operations in sequence
|
||||
mockDbChain.select.mockReturnValue(selectChain)
|
||||
vi.mocked(updateDocument).mockResolvedValue(updatedDocument)
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -238,8 +222,11 @@ describe('Document By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.filename).toBe('updated-document.pdf')
|
||||
expect(data.data.enabled).toBe(false)
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(vi.mocked(updateDocument)).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
validUpdateData,
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should validate update data', async () => {
|
||||
@@ -274,6 +261,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should mark document as failed due to timeout successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const processingDocument = {
|
||||
...mockDocument,
|
||||
@@ -288,34 +276,11 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Create a sequence of mocks for the database operations
|
||||
const updateChain = {
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined), // Update operation completes
|
||||
}),
|
||||
}
|
||||
|
||||
const selectChain = {
|
||||
from: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockReturnValue({
|
||||
limit: vi
|
||||
.fn()
|
||||
.mockResolvedValue([{ ...processingDocument, processingStatus: 'failed' }]),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
update: vi.fn().mockReturnValue(updateChain),
|
||||
}
|
||||
await callback(mockTx)
|
||||
vi.mocked(markDocumentAsFailedTimeout).mockResolvedValue({
|
||||
success: true,
|
||||
processingDuration: 200000,
|
||||
})
|
||||
|
||||
// Mock db operations in sequence
|
||||
mockDbChain.select.mockReturnValue(selectChain)
|
||||
|
||||
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -323,13 +288,13 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(updateChain.set).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
processingStatus: 'failed',
|
||||
processingError: 'Processing timed out - background process may have been terminated',
|
||||
processingCompletedAt: expect.any(Date),
|
||||
})
|
||||
expect(data.data.documentId).toBe('doc-123')
|
||||
expect(data.data.status).toBe('failed')
|
||||
expect(data.data.message).toBe('Document marked as failed due to timeout')
|
||||
expect(vi.mocked(markDocumentAsFailedTimeout)).toHaveBeenCalledWith(
|
||||
'doc-123',
|
||||
processingDocument.processingStartedAt,
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
@@ -354,6 +319,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should reject marking failed for recently started processing', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { markDocumentAsFailedTimeout } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const recentProcessingDocument = {
|
||||
...mockDocument,
|
||||
@@ -368,6 +334,10 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
vi.mocked(markDocumentAsFailedTimeout).mockRejectedValue(
|
||||
new Error('Document has not been processing long enough to be considered dead')
|
||||
)
|
||||
|
||||
const req = createMockRequest('PUT', { markFailedDueToTimeout: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -382,9 +352,8 @@ describe('Document By ID API Route', () => {
|
||||
const mockParams = Promise.resolve({ id: 'kb-123', documentId: 'doc-123' })
|
||||
|
||||
it('should retry processing successfully', async () => {
|
||||
const { checkDocumentWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { retryDocumentProcessing } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
const failedDocument = {
|
||||
...mockDocument,
|
||||
@@ -399,23 +368,12 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
delete: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
update: vi.fn().mockReturnValue({
|
||||
set: vi.fn().mockReturnValue({
|
||||
where: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
vi.mocked(retryDocumentProcessing).mockResolvedValue({
|
||||
success: true,
|
||||
status: 'pending',
|
||||
message: 'Document retry processing started',
|
||||
})
|
||||
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
|
||||
const req = createMockRequest('PUT', { retryProcessing: true })
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
const response = await PUT(req, { params: mockParams })
|
||||
@@ -425,8 +383,17 @@ describe('Document By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.status).toBe('pending')
|
||||
expect(data.data.message).toBe('Document retry processing started')
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(vi.mocked(processDocumentAsync)).toHaveBeenCalled()
|
||||
expect(vi.mocked(retryDocumentProcessing)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
'doc-123',
|
||||
{
|
||||
filename: failedDocument.filename,
|
||||
fileUrl: failedDocument.fileUrl,
|
||||
fileSize: failedDocument.fileSize,
|
||||
mimeType: failedDocument.mimeType,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should reject retry for non-failed document', async () => {
|
||||
@@ -486,6 +453,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should handle database errors during update', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { updateDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -494,8 +462,7 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to throw an error
|
||||
mockDbChain.transaction.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(updateDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -512,6 +479,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should delete document successfully', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -520,10 +488,10 @@ describe('Document By ID API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Properly chain the mock database operations for soft delete
|
||||
mockDbChain.update.mockReturnValue(mockDbChain)
|
||||
mockDbChain.set.mockReturnValue(mockDbChain)
|
||||
mockDbChain.where.mockResolvedValue(undefined) // Update operation resolves
|
||||
vi.mocked(deleteDocument).mockResolvedValue({
|
||||
success: true,
|
||||
message: 'Document deleted successfully',
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
@@ -533,12 +501,7 @@ describe('Document By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.message).toBe('Document deleted successfully')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockDbChain.set).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
deletedAt: expect.any(Date),
|
||||
})
|
||||
)
|
||||
expect(vi.mocked(deleteDocument)).toHaveBeenCalledWith('doc-123', expect.any(String))
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -592,6 +555,7 @@ describe('Document By ID API Route', () => {
|
||||
|
||||
it('should handle database errors during deletion', async () => {
|
||||
const { checkDocumentWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { deleteDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkDocumentWriteAccess).mockResolvedValue({
|
||||
@@ -599,7 +563,7 @@ describe('Document By ID API Route', () => {
|
||||
document: mockDocument,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.set.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(deleteDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/documents/[documentId]/route')
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
import {
|
||||
checkDocumentAccess,
|
||||
checkDocumentWriteAccess,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
deleteDocument,
|
||||
markDocumentAsFailedTimeout,
|
||||
retryDocumentProcessing,
|
||||
updateDocument,
|
||||
} from '@/lib/knowledge/documents/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
const logger = createLogger('DocumentByIdAPI')
|
||||
|
||||
@@ -116,9 +111,7 @@ export async function PUT(
|
||||
|
||||
const updateData: any = {}
|
||||
|
||||
// Handle special operations first
|
||||
if (validatedData.markFailedDueToTimeout) {
|
||||
// Mark document as failed due to timeout (replaces mark-failed endpoint)
|
||||
const doc = accessCheck.document
|
||||
|
||||
if (doc.processingStatus !== 'processing') {
|
||||
@@ -135,58 +128,30 @@ export async function PUT(
|
||||
)
|
||||
}
|
||||
|
||||
const now = new Date()
|
||||
const processingDuration = now.getTime() - new Date(doc.processingStartedAt).getTime()
|
||||
const DEAD_PROCESS_THRESHOLD_MS = 150 * 1000
|
||||
try {
|
||||
await markDocumentAsFailedTimeout(documentId, doc.processingStartedAt, requestId)
|
||||
|
||||
if (processingDuration <= DEAD_PROCESS_THRESHOLD_MS) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Document has not been processing long enough to be considered dead' },
|
||||
{ status: 400 }
|
||||
)
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
documentId,
|
||||
status: 'failed',
|
||||
message: 'Document marked as failed due to timeout',
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
return NextResponse.json({ error: error.message }, { status: 400 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
|
||||
updateData.processingStatus = 'failed'
|
||||
updateData.processingError =
|
||||
'Processing timed out - background process may have been terminated'
|
||||
updateData.processingCompletedAt = now
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Marked document ${documentId} as failed due to dead process (processing time: ${Math.round(processingDuration / 1000)}s)`
|
||||
)
|
||||
} else if (validatedData.retryProcessing) {
|
||||
// Retry processing (replaces retry endpoint)
|
||||
const doc = accessCheck.document
|
||||
|
||||
if (doc.processingStatus !== 'failed') {
|
||||
return NextResponse.json({ error: 'Document is not in failed state' }, { status: 400 })
|
||||
}
|
||||
|
||||
// Clear existing embeddings and reset document state
|
||||
await db.transaction(async (tx) => {
|
||||
await tx.delete(embedding).where(eq(embedding.documentId, documentId))
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'pending',
|
||||
processingStartedAt: null,
|
||||
processingCompletedAt: null,
|
||||
processingError: null,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
|
||||
const processingOptions = {
|
||||
chunkSize: 1024,
|
||||
minCharactersPerChunk: 24,
|
||||
recipe: 'default',
|
||||
lang: 'en',
|
||||
}
|
||||
|
||||
const docData = {
|
||||
filename: doc.filename,
|
||||
fileUrl: doc.fileUrl,
|
||||
@@ -194,80 +159,33 @@ export async function PUT(
|
||||
mimeType: doc.mimeType,
|
||||
}
|
||||
|
||||
processDocumentAsync(knowledgeBaseId, documentId, docData, processingOptions).catch(
|
||||
(error: unknown) => {
|
||||
logger.error(`[${requestId}] Background retry processing error:`, error)
|
||||
}
|
||||
const result = await retryDocumentProcessing(
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
docData,
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Document retry initiated: ${documentId}`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
documentId,
|
||||
status: 'pending',
|
||||
message: 'Document retry processing started',
|
||||
status: result.status,
|
||||
message: result.message,
|
||||
},
|
||||
})
|
||||
} else {
|
||||
// Regular field updates
|
||||
if (validatedData.filename !== undefined) updateData.filename = validatedData.filename
|
||||
if (validatedData.enabled !== undefined) updateData.enabled = validatedData.enabled
|
||||
if (validatedData.chunkCount !== undefined) updateData.chunkCount = validatedData.chunkCount
|
||||
if (validatedData.tokenCount !== undefined) updateData.tokenCount = validatedData.tokenCount
|
||||
if (validatedData.characterCount !== undefined)
|
||||
updateData.characterCount = validatedData.characterCount
|
||||
if (validatedData.processingStatus !== undefined)
|
||||
updateData.processingStatus = validatedData.processingStatus
|
||||
if (validatedData.processingError !== undefined)
|
||||
updateData.processingError = validatedData.processingError
|
||||
const updatedDocument = await updateDocument(documentId, validatedData, requestId)
|
||||
|
||||
// Tag field updates
|
||||
TAG_SLOTS.forEach((slot) => {
|
||||
if ((validatedData as any)[slot] !== undefined) {
|
||||
;(updateData as any)[slot] = (validatedData as any)[slot]
|
||||
}
|
||||
logger.info(
|
||||
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedDocument,
|
||||
})
|
||||
}
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
// Update the document
|
||||
await tx.update(document).set(updateData).where(eq(document.id, documentId))
|
||||
|
||||
// If any tag fields were updated, also update the embeddings
|
||||
const hasTagUpdates = TAG_SLOTS.some((field) => (validatedData as any)[field] !== undefined)
|
||||
|
||||
if (hasTagUpdates) {
|
||||
const embeddingUpdateData: Record<string, string | null> = {}
|
||||
TAG_SLOTS.forEach((field) => {
|
||||
if ((validatedData as any)[field] !== undefined) {
|
||||
embeddingUpdateData[field] = (validatedData as any)[field] || null
|
||||
}
|
||||
})
|
||||
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set(embeddingUpdateData)
|
||||
.where(eq(embedding.documentId, documentId))
|
||||
}
|
||||
})
|
||||
|
||||
// Fetch the updated document
|
||||
const updatedDocument = await db
|
||||
.select()
|
||||
.from(document)
|
||||
.where(eq(document.id, documentId))
|
||||
.limit(1)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document updated: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedDocument[0],
|
||||
})
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid document update data`, {
|
||||
@@ -316,13 +234,7 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Soft delete by setting deletedAt timestamp
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
const result = await deleteDocument(documentId, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document deleted: ${documentId} from knowledge base ${knowledgeBaseId}`
|
||||
@@ -330,7 +242,7 @@ export async function DELETE(
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: { message: 'Document deleted successfully' },
|
||||
data: result,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error deleting document`, error)
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
|
||||
import {
|
||||
getMaxSlotsForFieldType,
|
||||
getSlotsForFieldType,
|
||||
SUPPORTED_FIELD_TYPES,
|
||||
} from '@/lib/constants/knowledge'
|
||||
cleanupUnusedTagDefinitions,
|
||||
createOrUpdateTagDefinitionsBulk,
|
||||
deleteAllTagDefinitions,
|
||||
getDocumentTagDefinitions,
|
||||
} from '@/lib/knowledge/tags/service'
|
||||
import type { BulkTagDefinitionsData } from '@/lib/knowledge/tags/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { checkDocumentAccess, checkDocumentWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -29,106 +29,6 @@ const BulkTagDefinitionsSchema = z.object({
|
||||
definitions: z.array(TagDefinitionSchema),
|
||||
})
|
||||
|
||||
// Helper function to get the next available slot for a knowledge base and field type
|
||||
async function getNextAvailableSlot(
|
||||
knowledgeBaseId: string,
|
||||
fieldType: string,
|
||||
existingBySlot?: Map<string, any>
|
||||
): Promise<string | null> {
|
||||
// Get available slots for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
let usedSlots: Set<string>
|
||||
|
||||
if (existingBySlot) {
|
||||
// Use provided map if available (for performance in batch operations)
|
||||
// Filter by field type
|
||||
usedSlots = new Set(
|
||||
Array.from(existingBySlot.entries())
|
||||
.filter(([_, def]) => def.fieldType === fieldType)
|
||||
.map(([slot, _]) => slot)
|
||||
)
|
||||
} else {
|
||||
// Query database for existing tag definitions of the same field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
|
||||
}
|
||||
|
||||
// Find the first available slot for this field type
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
return null // No available slots for this field type
|
||||
}
|
||||
|
||||
// Helper function to clean up unused tag definitions
|
||||
async function cleanupUnusedTagDefinitions(knowledgeBaseId: string, requestId: string) {
|
||||
try {
|
||||
logger.info(`[${requestId}] Starting cleanup for KB ${knowledgeBaseId}`)
|
||||
|
||||
// Get all tag definitions for this KB
|
||||
const allDefinitions = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
logger.info(`[${requestId}] Found ${allDefinitions.length} tag definitions to check`)
|
||||
|
||||
if (allDefinitions.length === 0) {
|
||||
return 0
|
||||
}
|
||||
|
||||
let cleanedCount = 0
|
||||
|
||||
// For each tag definition, check if any documents use that tag slot
|
||||
for (const definition of allDefinitions) {
|
||||
const slot = definition.tagSlot
|
||||
|
||||
// Use raw SQL with proper column name injection
|
||||
const countResult = await db.execute(sql`
|
||||
SELECT count(*) as count
|
||||
FROM document
|
||||
WHERE knowledge_base_id = ${knowledgeBaseId}
|
||||
AND ${sql.raw(slot)} IS NOT NULL
|
||||
AND trim(${sql.raw(slot)}) != ''
|
||||
`)
|
||||
const count = Number(countResult[0]?.count) || 0
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Tag ${definition.displayName} (${slot}): ${count} documents using it`
|
||||
)
|
||||
|
||||
// If count is 0, remove this tag definition
|
||||
if (count === 0) {
|
||||
await db
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, definition.id))
|
||||
|
||||
cleanedCount++
|
||||
logger.info(
|
||||
`[${requestId}] Removed unused tag definition: ${definition.displayName} (${definition.tagSlot})`
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return cleanedCount
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to cleanup unused tag definitions:`, error)
|
||||
return 0 // Don't fail the main operation if cleanup fails
|
||||
}
|
||||
}
|
||||
|
||||
// GET /api/knowledge/[id]/documents/[documentId]/tag-definitions - Get tag definitions for a document
|
||||
export async function GET(
|
||||
req: NextRequest,
|
||||
@@ -145,35 +45,22 @@ export async function GET(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Verify document exists and belongs to the knowledge base
|
||||
const documentExists = await db
|
||||
.select({ id: document.id })
|
||||
.from(document)
|
||||
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
|
||||
.limit(1)
|
||||
|
||||
if (documentExists.length === 0) {
|
||||
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
|
||||
const accessCheck = await checkDocumentAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Get tag definitions for the knowledge base
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
const tagDefinitions = await getDocumentTagDefinitions(knowledgeBaseId)
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
|
||||
|
||||
@@ -203,21 +90,19 @@ export async function POST(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has write access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
|
||||
// Verify document exists and user has write access
|
||||
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Verify document exists and belongs to the knowledge base
|
||||
const documentExists = await db
|
||||
.select({ id: document.id })
|
||||
.from(document)
|
||||
.where(and(eq(document.id, documentId), eq(document.knowledgeBaseId, knowledgeBaseId)))
|
||||
.limit(1)
|
||||
|
||||
if (documentExists.length === 0) {
|
||||
return NextResponse.json({ error: 'Document not found' }, { status: 404 })
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
let body
|
||||
@@ -238,197 +123,24 @@ export async function POST(
|
||||
|
||||
const validatedData = BulkTagDefinitionsSchema.parse(body)
|
||||
|
||||
// Validate slots are valid for their field types
|
||||
for (const definition of validatedData.definitions) {
|
||||
const validSlots = getSlotsForFieldType(definition.fieldType)
|
||||
if (validSlots.length === 0) {
|
||||
return NextResponse.json(
|
||||
{ error: `Unsupported field type: ${definition.fieldType}` },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
|
||||
if (!validSlots.includes(definition.tagSlot)) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Invalid slot '${definition.tagSlot}' for field type '${definition.fieldType}'. Valid slots: ${validSlots.join(', ')}`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
const bulkData: BulkTagDefinitionsData = {
|
||||
definitions: validatedData.definitions.map((def) => ({
|
||||
tagSlot: def.tagSlot,
|
||||
displayName: def.displayName,
|
||||
fieldType: def.fieldType,
|
||||
originalDisplayName: def._originalDisplayName,
|
||||
})),
|
||||
}
|
||||
|
||||
// Validate no duplicate tag slots within the same field type
|
||||
const slotsByFieldType = new Map<string, Set<string>>()
|
||||
for (const definition of validatedData.definitions) {
|
||||
if (!slotsByFieldType.has(definition.fieldType)) {
|
||||
slotsByFieldType.set(definition.fieldType, new Set())
|
||||
}
|
||||
const slotsForType = slotsByFieldType.get(definition.fieldType)!
|
||||
if (slotsForType.has(definition.tagSlot)) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Duplicate slot '${definition.tagSlot}' for field type '${definition.fieldType}'`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
slotsForType.add(definition.tagSlot)
|
||||
}
|
||||
|
||||
const now = new Date()
|
||||
const createdDefinitions: (typeof knowledgeBaseTagDefinitions.$inferSelect)[] = []
|
||||
|
||||
// Get existing definitions
|
||||
const existingDefinitions = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
// Group by field type for validation
|
||||
const existingByFieldType = new Map<string, number>()
|
||||
for (const def of existingDefinitions) {
|
||||
existingByFieldType.set(def.fieldType, (existingByFieldType.get(def.fieldType) || 0) + 1)
|
||||
}
|
||||
|
||||
// Validate we don't exceed limits per field type
|
||||
const newByFieldType = new Map<string, number>()
|
||||
for (const definition of validatedData.definitions) {
|
||||
// Skip validation for edit operations - they don't create new slots
|
||||
if (definition._originalDisplayName) {
|
||||
continue
|
||||
}
|
||||
|
||||
const existingTagNames = new Set(
|
||||
existingDefinitions
|
||||
.filter((def) => def.fieldType === definition.fieldType)
|
||||
.map((def) => def.displayName)
|
||||
)
|
||||
|
||||
if (!existingTagNames.has(definition.displayName)) {
|
||||
newByFieldType.set(
|
||||
definition.fieldType,
|
||||
(newByFieldType.get(definition.fieldType) || 0) + 1
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
for (const [fieldType, newCount] of newByFieldType.entries()) {
|
||||
const existingCount = existingByFieldType.get(fieldType) || 0
|
||||
const maxSlots = getMaxSlotsForFieldType(fieldType)
|
||||
|
||||
if (existingCount + newCount > maxSlots) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: `Cannot create ${newCount} new '${fieldType}' tags. Knowledge base already has ${existingCount} '${fieldType}' tag definitions. Maximum is ${maxSlots} per field type.`,
|
||||
},
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Use transaction to ensure consistency
|
||||
await db.transaction(async (tx) => {
|
||||
// Create maps for lookups
|
||||
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
|
||||
|
||||
// Process each definition
|
||||
for (const definition of validatedData.definitions) {
|
||||
if (definition._originalDisplayName) {
|
||||
// This is an EDIT operation - find by original name and update
|
||||
const originalDefinition = existingByName.get(definition._originalDisplayName)
|
||||
|
||||
if (originalDefinition) {
|
||||
logger.info(
|
||||
`[${requestId}] Editing tag definition: ${definition._originalDisplayName} -> ${definition.displayName} (slot ${originalDefinition.tagSlot})`
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set({
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, originalDefinition.id))
|
||||
|
||||
createdDefinitions.push({
|
||||
...originalDefinition,
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
continue
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] Could not find original definition for: ${definition._originalDisplayName}`
|
||||
)
|
||||
}
|
||||
|
||||
// Regular create/update logic
|
||||
const existingByDisplayName = existingByName.get(definition.displayName)
|
||||
|
||||
if (existingByDisplayName) {
|
||||
// Display name exists - UPDATE operation
|
||||
logger.info(
|
||||
`[${requestId}] Updating existing tag definition: ${definition.displayName} (slot ${existingByDisplayName.tagSlot})`
|
||||
)
|
||||
|
||||
await tx
|
||||
.update(knowledgeBaseTagDefinitions)
|
||||
.set({
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, existingByDisplayName.id))
|
||||
|
||||
createdDefinitions.push({
|
||||
...existingByDisplayName,
|
||||
fieldType: definition.fieldType,
|
||||
updatedAt: now,
|
||||
})
|
||||
} else {
|
||||
// Display name doesn't exist - CREATE operation
|
||||
const targetSlot = await getNextAvailableSlot(
|
||||
knowledgeBaseId,
|
||||
definition.fieldType,
|
||||
existingBySlot
|
||||
)
|
||||
|
||||
if (!targetSlot) {
|
||||
logger.error(
|
||||
`[${requestId}] No available slots for new tag definition: ${definition.displayName}`
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Creating new tag definition: ${definition.displayName} -> ${targetSlot}`
|
||||
)
|
||||
|
||||
const newDefinition = {
|
||||
id: randomUUID(),
|
||||
knowledgeBaseId,
|
||||
tagSlot: targetSlot as any,
|
||||
displayName: definition.displayName,
|
||||
fieldType: definition.fieldType,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await tx.insert(knowledgeBaseTagDefinitions).values(newDefinition)
|
||||
existingBySlot.set(targetSlot as any, newDefinition)
|
||||
createdDefinitions.push(newDefinition as any)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(`[${requestId}] Created/updated ${createdDefinitions.length} tag definitions`)
|
||||
const result = await createOrUpdateTagDefinitionsBulk(knowledgeBaseId, bulkData, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: createdDefinitions,
|
||||
data: {
|
||||
created: result.created,
|
||||
updated: result.updated,
|
||||
errors: result.errors,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
@@ -459,10 +171,19 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has write access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseWriteAccess(knowledgeBaseId, session.user.id)
|
||||
// Verify document exists and user has write access
|
||||
const accessCheck = await checkDocumentWriteAccess(knowledgeBaseId, documentId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
if (accessCheck.notFound) {
|
||||
logger.warn(
|
||||
`[${requestId}] ${accessCheck.reason}: KB=${knowledgeBaseId}, Doc=${documentId}`
|
||||
)
|
||||
return NextResponse.json({ error: accessCheck.reason }, { status: 404 })
|
||||
}
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} attempted unauthorized document write access: ${accessCheck.reason}`
|
||||
)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
if (action === 'cleanup') {
|
||||
@@ -478,13 +199,12 @@ export async function DELETE(
|
||||
// Delete all tag definitions (original behavior)
|
||||
logger.info(`[${requestId}] Deleting all tag definitions for KB ${knowledgeBaseId}`)
|
||||
|
||||
const result = await db
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
const deletedCount = await deleteAllTagDefinitions(knowledgeBaseId, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
message: 'Tag definitions deleted successfully',
|
||||
data: { deleted: deletedCount },
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error with tag definitions operation`, error)
|
||||
|
||||
@@ -24,6 +24,19 @@ vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
processDocumentAsync: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/knowledge/documents/service', () => ({
|
||||
getDocuments: vi.fn(),
|
||||
createSingleDocument: vi.fn(),
|
||||
createDocumentRecords: vi.fn(),
|
||||
processDocumentsWithQueue: vi.fn(),
|
||||
getProcessingConfig: vi.fn(),
|
||||
bulkDocumentOperation: vi.fn(),
|
||||
updateDocument: vi.fn(),
|
||||
deleteDocument: vi.fn(),
|
||||
markDocumentAsFailedTimeout: vi.fn(),
|
||||
retryDocumentProcessing: vi.fn(),
|
||||
}))
|
||||
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
@@ -72,7 +85,6 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
}
|
||||
}
|
||||
})
|
||||
// Clear all mocks - they will be set up in individual tests
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
@@ -96,6 +108,7 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should retrieve documents successfully for authenticated user', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -103,11 +116,15 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -118,12 +135,22 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.documents).toHaveLength(1)
|
||||
expect(data.data.documents[0].id).toBe('doc-123')
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(vi.mocked(checkKnowledgeBaseAccess)).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: false,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should filter disabled documents by default', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -131,22 +158,36 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
const response = await GET(req, { params: mockParams })
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(mockDbChain.where).toHaveBeenCalled()
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: false,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should include disabled documents when requested', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
@@ -154,11 +195,15 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock the count query (first query)
|
||||
mockDbChain.where.mockResolvedValueOnce([{ count: 1 }])
|
||||
|
||||
// Mock the documents query (second query)
|
||||
mockDbChain.offset.mockResolvedValue([mockDocument])
|
||||
vi.mocked(getDocuments).mockResolvedValue({
|
||||
documents: [mockDocument],
|
||||
pagination: {
|
||||
total: 1,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
hasMore: false,
|
||||
},
|
||||
})
|
||||
|
||||
const url = 'http://localhost:3000/api/knowledge/kb-123/documents?includeDisabled=true'
|
||||
const req = new Request(url, { method: 'GET' }) as any
|
||||
@@ -167,6 +212,16 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
const response = await GET(req, { params: mockParams })
|
||||
|
||||
expect(response.status).toBe(200)
|
||||
expect(vi.mocked(getDocuments)).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
includeDisabled: true,
|
||||
search: undefined,
|
||||
limit: 50,
|
||||
offset: 0,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -216,13 +271,14 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
const { checkKnowledgeBaseAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { getDocuments } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.orderBy.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(getDocuments).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -245,13 +301,35 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should create single document successfully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.values.mockResolvedValue(undefined)
|
||||
|
||||
const createdDocument = {
|
||||
id: 'doc-123',
|
||||
knowledgeBaseId: 'kb-123',
|
||||
filename: validDocumentData.filename,
|
||||
fileUrl: validDocumentData.fileUrl,
|
||||
fileSize: validDocumentData.fileSize,
|
||||
mimeType: validDocumentData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
enabled: true,
|
||||
uploadedAt: new Date(),
|
||||
tag1: null,
|
||||
tag2: null,
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
}
|
||||
vi.mocked(createSingleDocument).mockResolvedValue(createdDocument)
|
||||
|
||||
const req = createMockRequest('POST', validDocumentData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -262,7 +340,11 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.filename).toBe(validDocumentData.filename)
|
||||
expect(data.data.fileUrl).toBe(validDocumentData.fileUrl)
|
||||
expect(mockDbChain.insert).toHaveBeenCalled()
|
||||
expect(vi.mocked(createSingleDocument)).toHaveBeenCalledWith(
|
||||
validDocumentData,
|
||||
'kb-123',
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should validate single document data', async () => {
|
||||
@@ -320,9 +402,9 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
}
|
||||
|
||||
it('should create bulk documents successfully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
|
||||
await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
@@ -330,17 +412,31 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to return the created documents
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
insert: vi.fn().mockReturnValue({
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
})
|
||||
const createdDocuments = [
|
||||
{
|
||||
documentId: 'doc-1',
|
||||
filename: 'doc1.pdf',
|
||||
fileUrl: 'https://example.com/doc1.pdf',
|
||||
fileSize: 1024,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
{
|
||||
documentId: 'doc-2',
|
||||
filename: 'doc2.pdf',
|
||||
fileUrl: 'https://example.com/doc2.pdf',
|
||||
fileSize: 2048,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
]
|
||||
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
|
||||
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
|
||||
vi.mocked(getProcessingConfig).mockReturnValue({
|
||||
maxConcurrentDocuments: 8,
|
||||
batchSize: 20,
|
||||
delayBetweenBatches: 100,
|
||||
delayBetweenDocuments: 0,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validBulkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
@@ -352,7 +448,12 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
expect(data.data.total).toBe(2)
|
||||
expect(data.data.documentsCreated).toHaveLength(2)
|
||||
expect(data.data.processingMethod).toBe('background')
|
||||
expect(mockDbChain.transaction).toHaveBeenCalled()
|
||||
expect(vi.mocked(createDocumentRecords)).toHaveBeenCalledWith(
|
||||
validBulkData.documents,
|
||||
'kb-123',
|
||||
expect.any(String)
|
||||
)
|
||||
expect(vi.mocked(processDocumentsWithQueue)).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate bulk document data', async () => {
|
||||
@@ -394,9 +495,9 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
})
|
||||
|
||||
it('should handle processing errors gracefully', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess, processDocumentAsync } = await import(
|
||||
'@/app/api/knowledge/utils'
|
||||
)
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createDocumentRecords, processDocumentsWithQueue, getProcessingConfig } =
|
||||
await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
@@ -404,26 +505,30 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
// Mock transaction to succeed but processing to fail
|
||||
mockDbChain.transaction.mockImplementation(async (callback) => {
|
||||
const mockTx = {
|
||||
insert: vi.fn().mockReturnValue({
|
||||
values: vi.fn().mockResolvedValue(undefined),
|
||||
}),
|
||||
}
|
||||
return await callback(mockTx)
|
||||
})
|
||||
const createdDocuments = [
|
||||
{
|
||||
documentId: 'doc-1',
|
||||
filename: 'doc1.pdf',
|
||||
fileUrl: 'https://example.com/doc1.pdf',
|
||||
fileSize: 1024,
|
||||
mimeType: 'application/pdf',
|
||||
},
|
||||
]
|
||||
|
||||
// Don't reject the promise - the processing is async and catches errors internally
|
||||
vi.mocked(processDocumentAsync).mockResolvedValue(undefined)
|
||||
vi.mocked(createDocumentRecords).mockResolvedValue(createdDocuments)
|
||||
vi.mocked(processDocumentsWithQueue).mockResolvedValue(undefined)
|
||||
vi.mocked(getProcessingConfig).mockReturnValue({
|
||||
maxConcurrentDocuments: 8,
|
||||
batchSize: 20,
|
||||
delayBetweenBatches: 100,
|
||||
delayBetweenDocuments: 0,
|
||||
})
|
||||
|
||||
const req = createMockRequest('POST', validBulkData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
const response = await POST(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
// The endpoint should still return success since documents are created
|
||||
// and processing happens asynchronously
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
})
|
||||
@@ -485,13 +590,14 @@ describe('Knowledge Base Documents API Route', () => {
|
||||
|
||||
it('should handle database errors during creation', async () => {
|
||||
const { checkKnowledgeBaseWriteAccess } = await import('@/app/api/knowledge/utils')
|
||||
const { createSingleDocument } = await import('@/lib/knowledge/documents/service')
|
||||
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
vi.mocked(checkKnowledgeBaseWriteAccess).mockResolvedValue({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
mockDbChain.values.mockRejectedValue(new Error('Database error'))
|
||||
vi.mocked(createSingleDocument).mockRejectedValue(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('POST', validDocumentData)
|
||||
const { POST } = await import('@/app/api/knowledge/[id]/documents/route')
|
||||
|
||||
@@ -1,279 +1,22 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, desc, eq, inArray, isNull, sql } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getSlotsForFieldType } from '@/lib/constants/knowledge'
|
||||
import {
|
||||
bulkDocumentOperation,
|
||||
createDocumentRecords,
|
||||
createSingleDocument,
|
||||
getDocuments,
|
||||
getProcessingConfig,
|
||||
processDocumentsWithQueue,
|
||||
} from '@/lib/knowledge/documents/service'
|
||||
import type { DocumentSortField, SortOrder } from '@/lib/knowledge/documents/types'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import {
|
||||
checkKnowledgeBaseAccess,
|
||||
checkKnowledgeBaseWriteAccess,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
|
||||
const logger = createLogger('DocumentsAPI')
|
||||
|
||||
const PROCESSING_CONFIG = {
|
||||
maxConcurrentDocuments: 3,
|
||||
batchSize: 5,
|
||||
delayBetweenBatches: 1000,
|
||||
delayBetweenDocuments: 500,
|
||||
}
|
||||
|
||||
// Helper function to get the next available slot for a knowledge base and field type
|
||||
async function getNextAvailableSlot(
|
||||
knowledgeBaseId: string,
|
||||
fieldType: string,
|
||||
existingBySlot?: Map<string, any>
|
||||
): Promise<string | null> {
|
||||
let usedSlots: Set<string>
|
||||
|
||||
if (existingBySlot) {
|
||||
// Use provided map if available (for performance in batch operations)
|
||||
// Filter by field type
|
||||
usedSlots = new Set(
|
||||
Array.from(existingBySlot.entries())
|
||||
.filter(([_, def]) => def.fieldType === fieldType)
|
||||
.map(([slot, _]) => slot)
|
||||
)
|
||||
} else {
|
||||
// Query database for existing tag definitions of the same field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot))
|
||||
}
|
||||
|
||||
// Find the first available slot for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
return slot
|
||||
}
|
||||
}
|
||||
|
||||
return null // No available slots for this field type
|
||||
}
|
||||
|
||||
// Helper function to process structured document tags
|
||||
async function processDocumentTags(
|
||||
knowledgeBaseId: string,
|
||||
tagData: Array<{ tagName: string; fieldType: string; value: string }>,
|
||||
requestId: string
|
||||
): Promise<Record<string, string | null>> {
|
||||
const result: Record<string, string | null> = {}
|
||||
|
||||
// Initialize all text tag slots to null (only text type is supported currently)
|
||||
const textSlots = getSlotsForFieldType('text')
|
||||
textSlots.forEach((slot) => {
|
||||
result[slot] = null
|
||||
})
|
||||
|
||||
if (!Array.isArray(tagData) || tagData.length === 0) {
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// Get existing tag definitions
|
||||
const existingDefinitions = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
const existingByName = new Map(existingDefinitions.map((def) => [def.displayName, def]))
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot, def]))
|
||||
|
||||
// Process each tag
|
||||
for (const tag of tagData) {
|
||||
if (!tag.tagName?.trim() || !tag.value?.trim()) continue
|
||||
|
||||
const tagName = tag.tagName.trim()
|
||||
const fieldType = tag.fieldType
|
||||
const value = tag.value.trim()
|
||||
|
||||
let targetSlot: string | null = null
|
||||
|
||||
// Check if tag definition already exists
|
||||
const existingDef = existingByName.get(tagName)
|
||||
if (existingDef) {
|
||||
targetSlot = existingDef.tagSlot
|
||||
} else {
|
||||
// Find next available slot using the helper function
|
||||
targetSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
|
||||
|
||||
// Create new tag definition if we have a slot
|
||||
if (targetSlot) {
|
||||
const newDefinition = {
|
||||
id: randomUUID(),
|
||||
knowledgeBaseId,
|
||||
tagSlot: targetSlot as any,
|
||||
displayName: tagName,
|
||||
fieldType,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBaseTagDefinitions).values(newDefinition)
|
||||
existingBySlot.set(targetSlot as any, newDefinition)
|
||||
|
||||
logger.info(`[${requestId}] Created tag definition: ${tagName} -> ${targetSlot}`)
|
||||
}
|
||||
}
|
||||
|
||||
// Assign value to the slot
|
||||
if (targetSlot) {
|
||||
result[targetSlot] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error processing document tags:`, error)
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
async function processDocumentsWithConcurrencyControl(
|
||||
createdDocuments: Array<{
|
||||
documentId: string
|
||||
filename: string
|
||||
fileUrl: string
|
||||
fileSize: number
|
||||
mimeType: string
|
||||
}>,
|
||||
knowledgeBaseId: string,
|
||||
processingOptions: {
|
||||
chunkSize: number
|
||||
minCharactersPerChunk: number
|
||||
recipe: string
|
||||
lang: string
|
||||
chunkOverlap: number
|
||||
},
|
||||
requestId: string
|
||||
): Promise<void> {
|
||||
const totalDocuments = createdDocuments.length
|
||||
const batches = []
|
||||
|
||||
for (let i = 0; i < totalDocuments; i += PROCESSING_CONFIG.batchSize) {
|
||||
batches.push(createdDocuments.slice(i, i + PROCESSING_CONFIG.batchSize))
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Processing ${totalDocuments} documents in ${batches.length} batches`)
|
||||
|
||||
for (const [batchIndex, batch] of batches.entries()) {
|
||||
logger.info(
|
||||
`[${requestId}] Starting batch ${batchIndex + 1}/${batches.length} with ${batch.length} documents`
|
||||
)
|
||||
|
||||
await processBatchWithConcurrency(batch, knowledgeBaseId, processingOptions, requestId)
|
||||
|
||||
if (batchIndex < batches.length - 1) {
|
||||
await new Promise((resolve) => setTimeout(resolve, PROCESSING_CONFIG.delayBetweenBatches))
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`[${requestId}] Completed processing initiation for all ${totalDocuments} documents`)
|
||||
}
|
||||
|
||||
async function processBatchWithConcurrency(
|
||||
batch: Array<{
|
||||
documentId: string
|
||||
filename: string
|
||||
fileUrl: string
|
||||
fileSize: number
|
||||
mimeType: string
|
||||
}>,
|
||||
knowledgeBaseId: string,
|
||||
processingOptions: {
|
||||
chunkSize: number
|
||||
minCharactersPerChunk: number
|
||||
recipe: string
|
||||
lang: string
|
||||
chunkOverlap: number
|
||||
},
|
||||
requestId: string
|
||||
): Promise<void> {
|
||||
const semaphore = new Array(PROCESSING_CONFIG.maxConcurrentDocuments).fill(0)
|
||||
const processingPromises = batch.map(async (doc, index) => {
|
||||
if (index > 0) {
|
||||
await new Promise((resolve) =>
|
||||
setTimeout(resolve, index * PROCESSING_CONFIG.delayBetweenDocuments)
|
||||
)
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
const checkSlot = () => {
|
||||
const availableIndex = semaphore.findIndex((slot) => slot === 0)
|
||||
if (availableIndex !== -1) {
|
||||
semaphore[availableIndex] = 1
|
||||
resolve()
|
||||
} else {
|
||||
setTimeout(checkSlot, 100)
|
||||
}
|
||||
}
|
||||
checkSlot()
|
||||
})
|
||||
|
||||
try {
|
||||
logger.info(`[${requestId}] Starting processing for document: ${doc.filename}`)
|
||||
|
||||
await processDocumentAsync(
|
||||
knowledgeBaseId,
|
||||
doc.documentId,
|
||||
{
|
||||
filename: doc.filename,
|
||||
fileUrl: doc.fileUrl,
|
||||
fileSize: doc.fileSize,
|
||||
mimeType: doc.mimeType,
|
||||
},
|
||||
processingOptions
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Successfully initiated processing for document: ${doc.filename}`)
|
||||
} catch (error: unknown) {
|
||||
logger.error(`[${requestId}] Failed to process document: ${doc.filename}`, {
|
||||
documentId: doc.documentId,
|
||||
filename: doc.filename,
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
})
|
||||
|
||||
try {
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'failed',
|
||||
processingError:
|
||||
error instanceof Error ? error.message : 'Failed to initiate processing',
|
||||
processingCompletedAt: new Date(),
|
||||
})
|
||||
.where(eq(document.id, doc.documentId))
|
||||
} catch (dbError: unknown) {
|
||||
logger.error(
|
||||
`[${requestId}] Failed to update document status for failed document: ${doc.documentId}`,
|
||||
dbError
|
||||
)
|
||||
}
|
||||
} finally {
|
||||
const slotIndex = semaphore.findIndex((slot) => slot === 1)
|
||||
if (slotIndex !== -1) {
|
||||
semaphore[slotIndex] = 0
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
await Promise.allSettled(processingPromises)
|
||||
}
|
||||
|
||||
const CreateDocumentSchema = z.object({
|
||||
filename: z.string().min(1, 'Filename is required'),
|
||||
fileUrl: z.string().url('File URL must be valid'),
|
||||
@@ -337,83 +80,50 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
|
||||
const url = new URL(req.url)
|
||||
const includeDisabled = url.searchParams.get('includeDisabled') === 'true'
|
||||
const search = url.searchParams.get('search')
|
||||
const search = url.searchParams.get('search') || undefined
|
||||
const limit = Number.parseInt(url.searchParams.get('limit') || '50')
|
||||
const offset = Number.parseInt(url.searchParams.get('offset') || '0')
|
||||
const sortByParam = url.searchParams.get('sortBy')
|
||||
const sortOrderParam = url.searchParams.get('sortOrder')
|
||||
|
||||
// Build where conditions
|
||||
const whereConditions = [
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNull(document.deletedAt),
|
||||
// Validate sort parameters
|
||||
const validSortFields: DocumentSortField[] = [
|
||||
'filename',
|
||||
'fileSize',
|
||||
'tokenCount',
|
||||
'chunkCount',
|
||||
'uploadedAt',
|
||||
'processingStatus',
|
||||
]
|
||||
const validSortOrders: SortOrder[] = ['asc', 'desc']
|
||||
|
||||
// Filter out disabled documents unless specifically requested
|
||||
if (!includeDisabled) {
|
||||
whereConditions.push(eq(document.enabled, true))
|
||||
}
|
||||
const sortBy =
|
||||
sortByParam && validSortFields.includes(sortByParam as DocumentSortField)
|
||||
? (sortByParam as DocumentSortField)
|
||||
: undefined
|
||||
const sortOrder =
|
||||
sortOrderParam && validSortOrders.includes(sortOrderParam as SortOrder)
|
||||
? (sortOrderParam as SortOrder)
|
||||
: undefined
|
||||
|
||||
// Add search condition if provided
|
||||
if (search) {
|
||||
whereConditions.push(
|
||||
// Search in filename
|
||||
sql`LOWER(${document.filename}) LIKE LOWER(${`%${search}%`})`
|
||||
)
|
||||
}
|
||||
|
||||
// Get total count for pagination
|
||||
const totalResult = await db
|
||||
.select({ count: sql<number>`COUNT(*)` })
|
||||
.from(document)
|
||||
.where(and(...whereConditions))
|
||||
|
||||
const total = totalResult[0]?.count || 0
|
||||
const hasMore = offset + limit < total
|
||||
|
||||
const documents = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
filename: document.filename,
|
||||
fileUrl: document.fileUrl,
|
||||
fileSize: document.fileSize,
|
||||
mimeType: document.mimeType,
|
||||
chunkCount: document.chunkCount,
|
||||
tokenCount: document.tokenCount,
|
||||
characterCount: document.characterCount,
|
||||
processingStatus: document.processingStatus,
|
||||
processingStartedAt: document.processingStartedAt,
|
||||
processingCompletedAt: document.processingCompletedAt,
|
||||
processingError: document.processingError,
|
||||
enabled: document.enabled,
|
||||
uploadedAt: document.uploadedAt,
|
||||
// Include tags in response
|
||||
tag1: document.tag1,
|
||||
tag2: document.tag2,
|
||||
tag3: document.tag3,
|
||||
tag4: document.tag4,
|
||||
tag5: document.tag5,
|
||||
tag6: document.tag6,
|
||||
tag7: document.tag7,
|
||||
})
|
||||
.from(document)
|
||||
.where(and(...whereConditions))
|
||||
.orderBy(desc(document.uploadedAt))
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Retrieved ${documents.length} documents (${offset}-${offset + documents.length} of ${total}) for knowledge base ${knowledgeBaseId}`
|
||||
const result = await getDocuments(
|
||||
knowledgeBaseId,
|
||||
{
|
||||
includeDisabled,
|
||||
search,
|
||||
limit,
|
||||
offset,
|
||||
...(sortBy && { sortBy }),
|
||||
...(sortOrder && { sortOrder }),
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
documents,
|
||||
pagination: {
|
||||
total,
|
||||
limit,
|
||||
offset,
|
||||
hasMore,
|
||||
},
|
||||
documents: result.documents,
|
||||
pagination: result.pagination,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
@@ -462,80 +172,21 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if this is a bulk operation
|
||||
if (body.bulk === true) {
|
||||
// Handle bulk processing (replaces process-documents endpoint)
|
||||
try {
|
||||
const validatedData = BulkCreateDocumentsSchema.parse(body)
|
||||
|
||||
const createdDocuments = await db.transaction(async (tx) => {
|
||||
const documentPromises = validatedData.documents.map(async (docData) => {
|
||||
const documentId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Process documentTagsData if provided (for knowledge base block)
|
||||
let processedTags: Record<string, string | null> = {
|
||||
tag1: null,
|
||||
tag2: null,
|
||||
tag3: null,
|
||||
tag4: null,
|
||||
tag5: null,
|
||||
tag6: null,
|
||||
tag7: null,
|
||||
}
|
||||
|
||||
if (docData.documentTagsData) {
|
||||
try {
|
||||
const tagData = JSON.parse(docData.documentTagsData)
|
||||
if (Array.isArray(tagData)) {
|
||||
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[${requestId}] Failed to parse documentTagsData for bulk document:`,
|
||||
error
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const newDocument = {
|
||||
id: documentId,
|
||||
knowledgeBaseId,
|
||||
filename: docData.filename,
|
||||
fileUrl: docData.fileUrl,
|
||||
fileSize: docData.fileSize,
|
||||
mimeType: docData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
processingStatus: 'pending' as const,
|
||||
enabled: true,
|
||||
uploadedAt: now,
|
||||
// Use processed tags if available, otherwise fall back to individual tag fields
|
||||
tag1: processedTags.tag1 || docData.tag1 || null,
|
||||
tag2: processedTags.tag2 || docData.tag2 || null,
|
||||
tag3: processedTags.tag3 || docData.tag3 || null,
|
||||
tag4: processedTags.tag4 || docData.tag4 || null,
|
||||
tag5: processedTags.tag5 || docData.tag5 || null,
|
||||
tag6: processedTags.tag6 || docData.tag6 || null,
|
||||
tag7: processedTags.tag7 || docData.tag7 || null,
|
||||
}
|
||||
|
||||
await tx.insert(document).values(newDocument)
|
||||
logger.info(
|
||||
`[${requestId}] Document record created: ${documentId} for file: ${docData.filename}`
|
||||
)
|
||||
return { documentId, ...docData }
|
||||
})
|
||||
|
||||
return await Promise.all(documentPromises)
|
||||
})
|
||||
const createdDocuments = await createDocumentRecords(
|
||||
validatedData.documents,
|
||||
knowledgeBaseId,
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting controlled async processing of ${createdDocuments.length} documents`
|
||||
)
|
||||
|
||||
processDocumentsWithConcurrencyControl(
|
||||
processDocumentsWithQueue(
|
||||
createdDocuments,
|
||||
knowledgeBaseId,
|
||||
validatedData.processingOptions,
|
||||
@@ -555,9 +206,9 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
})),
|
||||
processingMethod: 'background',
|
||||
processingConfig: {
|
||||
maxConcurrentDocuments: PROCESSING_CONFIG.maxConcurrentDocuments,
|
||||
batchSize: PROCESSING_CONFIG.batchSize,
|
||||
totalBatches: Math.ceil(createdDocuments.length / PROCESSING_CONFIG.batchSize),
|
||||
maxConcurrentDocuments: getProcessingConfig().maxConcurrentDocuments,
|
||||
batchSize: getProcessingConfig().batchSize,
|
||||
totalBatches: Math.ceil(createdDocuments.length / getProcessingConfig().batchSize),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -578,52 +229,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
try {
|
||||
const validatedData = CreateDocumentSchema.parse(body)
|
||||
|
||||
const documentId = randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
// Process structured tag data if provided
|
||||
let processedTags: Record<string, string | null> = {
|
||||
tag1: validatedData.tag1 || null,
|
||||
tag2: validatedData.tag2 || null,
|
||||
tag3: validatedData.tag3 || null,
|
||||
tag4: validatedData.tag4 || null,
|
||||
tag5: validatedData.tag5 || null,
|
||||
tag6: validatedData.tag6 || null,
|
||||
tag7: validatedData.tag7 || null,
|
||||
}
|
||||
|
||||
if (validatedData.documentTagsData) {
|
||||
try {
|
||||
const tagData = JSON.parse(validatedData.documentTagsData)
|
||||
if (Array.isArray(tagData)) {
|
||||
// Process structured tag data and create tag definitions
|
||||
processedTags = await processDocumentTags(knowledgeBaseId, tagData, requestId)
|
||||
}
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to parse documentTagsData:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
const newDocument = {
|
||||
id: documentId,
|
||||
knowledgeBaseId,
|
||||
filename: validatedData.filename,
|
||||
fileUrl: validatedData.fileUrl,
|
||||
fileSize: validatedData.fileSize,
|
||||
mimeType: validatedData.mimeType,
|
||||
chunkCount: 0,
|
||||
tokenCount: 0,
|
||||
characterCount: 0,
|
||||
enabled: true,
|
||||
uploadedAt: now,
|
||||
...processedTags,
|
||||
}
|
||||
|
||||
await db.insert(document).values(newDocument)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Document created: ${documentId} in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
const newDocument = await createSingleDocument(validatedData, knowledgeBaseId, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -649,7 +255,7 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
}
|
||||
|
||||
export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id: string }> }) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
const requestId = randomUUID().slice(0, 8)
|
||||
const { id: knowledgeBaseId } = await params
|
||||
|
||||
try {
|
||||
@@ -678,89 +284,28 @@ export async function PATCH(req: NextRequest, { params }: { params: Promise<{ id
|
||||
const validatedData = BulkUpdateDocumentsSchema.parse(body)
|
||||
const { operation, documentIds } = validatedData
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Starting bulk ${operation} operation on ${documentIds.length} documents in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
// Verify all documents belong to this knowledge base and user has access
|
||||
const documentsToUpdate = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
enabled: document.enabled,
|
||||
})
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
|
||||
if (documentsToUpdate.length === 0) {
|
||||
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
|
||||
}
|
||||
|
||||
if (documentsToUpdate.length !== documentIds.length) {
|
||||
logger.warn(
|
||||
`[${requestId}] Some documents not found or don't belong to knowledge base. Requested: ${documentIds.length}, Found: ${documentsToUpdate.length}`
|
||||
)
|
||||
}
|
||||
|
||||
// Perform the bulk operation
|
||||
let updateResult: Array<{ id: string; enabled?: boolean; deletedAt?: Date | null }>
|
||||
let successCount: number
|
||||
|
||||
if (operation === 'delete') {
|
||||
// Handle bulk soft delete
|
||||
updateResult = await db
|
||||
.update(document)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
.returning({ id: document.id, deletedAt: document.deletedAt })
|
||||
|
||||
successCount = updateResult.length
|
||||
} else {
|
||||
// Handle bulk enable/disable
|
||||
const enabled = operation === 'enable'
|
||||
|
||||
updateResult = await db
|
||||
.update(document)
|
||||
.set({
|
||||
enabled,
|
||||
})
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
inArray(document.id, documentIds),
|
||||
isNull(document.deletedAt)
|
||||
)
|
||||
)
|
||||
.returning({ id: document.id, enabled: document.enabled })
|
||||
|
||||
successCount = updateResult.length
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Bulk ${operation} operation completed: ${successCount} documents updated in knowledge base ${knowledgeBaseId}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
try {
|
||||
const result = await bulkDocumentOperation(
|
||||
knowledgeBaseId,
|
||||
operation,
|
||||
successCount,
|
||||
updatedDocuments: updateResult,
|
||||
},
|
||||
})
|
||||
documentIds,
|
||||
requestId
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
operation,
|
||||
successCount: result.successCount,
|
||||
updatedDocuments: result.updatedDocuments,
|
||||
},
|
||||
})
|
||||
} catch (error) {
|
||||
if (error instanceof Error && error.message === 'No valid documents found to update') {
|
||||
return NextResponse.json({ error: 'No valid documents found to update' }, { status: 404 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
logger.warn(`[${requestId}] Invalid bulk operation data`, {
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getMaxSlotsForFieldType, getSlotsForFieldType } from '@/lib/constants/knowledge'
|
||||
import { getNextAvailableSlot, getTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('NextAvailableSlotAPI')
|
||||
|
||||
@@ -31,51 +28,36 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has read access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get available slots for this field type
|
||||
const availableSlots = getSlotsForFieldType(fieldType)
|
||||
const maxSlots = getMaxSlotsForFieldType(fieldType)
|
||||
// Get existing definitions once and reuse
|
||||
const existingDefinitions = await getTagDefinitions(knowledgeBaseId)
|
||||
const usedSlots = existingDefinitions
|
||||
.filter((def) => def.fieldType === fieldType)
|
||||
.map((def) => def.tagSlot)
|
||||
|
||||
// Get existing tag definitions to find used slots for this field type
|
||||
const existingDefinitions = await db
|
||||
.select({ tagSlot: knowledgeBaseTagDefinitions.tagSlot })
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.fieldType, fieldType)
|
||||
)
|
||||
)
|
||||
|
||||
const usedSlots = new Set(existingDefinitions.map((def) => def.tagSlot as string))
|
||||
|
||||
// Find the first available slot for this field type
|
||||
let nextAvailableSlot: string | null = null
|
||||
for (const slot of availableSlots) {
|
||||
if (!usedSlots.has(slot)) {
|
||||
nextAvailableSlot = slot
|
||||
break
|
||||
}
|
||||
}
|
||||
// Create a map for efficient lookup and pass to avoid redundant query
|
||||
const existingBySlot = new Map(existingDefinitions.map((def) => [def.tagSlot as string, def]))
|
||||
const nextAvailableSlot = await getNextAvailableSlot(knowledgeBaseId, fieldType, existingBySlot)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Next available slot for fieldType ${fieldType}: ${nextAvailableSlot}`
|
||||
)
|
||||
|
||||
const result = {
|
||||
nextAvailableSlot,
|
||||
fieldType,
|
||||
usedSlots,
|
||||
totalSlots: 7,
|
||||
availableSlots: nextAvailableSlot ? 7 - usedSlots.length : 0,
|
||||
}
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: {
|
||||
nextAvailableSlot,
|
||||
fieldType,
|
||||
usedSlots: Array.from(usedSlots),
|
||||
totalSlots: maxSlots,
|
||||
availableSlots: maxSlots - usedSlots.size,
|
||||
},
|
||||
data: result,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error getting next available slot`, error)
|
||||
|
||||
@@ -16,9 +16,26 @@ mockKnowledgeSchemas()
|
||||
mockDrizzleOrm()
|
||||
mockConsoleLogger()
|
||||
|
||||
vi.mock('@/lib/knowledge/service', () => ({
|
||||
getKnowledgeBaseById: vi.fn(),
|
||||
updateKnowledgeBase: vi.fn(),
|
||||
deleteKnowledgeBase: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/api/knowledge/utils', () => ({
|
||||
checkKnowledgeBaseAccess: vi.fn(),
|
||||
checkKnowledgeBaseWriteAccess: vi.fn(),
|
||||
}))
|
||||
|
||||
describe('Knowledge Base By ID API Route', () => {
|
||||
const mockAuth$ = mockAuth()
|
||||
|
||||
let mockGetKnowledgeBaseById: any
|
||||
let mockUpdateKnowledgeBase: any
|
||||
let mockDeleteKnowledgeBase: any
|
||||
let mockCheckKnowledgeBaseAccess: any
|
||||
let mockCheckKnowledgeBaseWriteAccess: any
|
||||
|
||||
const mockDbChain = {
|
||||
select: vi.fn().mockReturnThis(),
|
||||
from: vi.fn().mockReturnThis(),
|
||||
@@ -62,6 +79,15 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
})
|
||||
|
||||
const knowledgeService = await import('@/lib/knowledge/service')
|
||||
const knowledgeUtils = await import('@/app/api/knowledge/utils')
|
||||
|
||||
mockGetKnowledgeBaseById = knowledgeService.getKnowledgeBaseById as any
|
||||
mockUpdateKnowledgeBase = knowledgeService.updateKnowledgeBase as any
|
||||
mockDeleteKnowledgeBase = knowledgeService.deleteKnowledgeBase as any
|
||||
mockCheckKnowledgeBaseAccess = knowledgeUtils.checkKnowledgeBaseAccess as any
|
||||
mockCheckKnowledgeBaseWriteAccess = knowledgeUtils.checkKnowledgeBaseWriteAccess as any
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
@@ -74,9 +100,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should retrieve knowledge base successfully for authenticated user', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([mockKnowledgeBase])
|
||||
mockGetKnowledgeBaseById.mockResolvedValueOnce(mockKnowledgeBase)
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -87,7 +116,8 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.id).toBe('kb-123')
|
||||
expect(data.data.name).toBe('Test Knowledge Base')
|
||||
expect(mockDbChain.select).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockGetKnowledgeBaseById).toHaveBeenCalledWith('kb-123')
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -105,7 +135,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should return not found for non-existent knowledge base', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -119,7 +152,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should return unauthorized for knowledge base owned by different user', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
})
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -130,9 +166,29 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(data.error).toBe('Unauthorized')
|
||||
})
|
||||
|
||||
it('should return not found when service returns null', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockCheckKnowledgeBaseAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockGetKnowledgeBaseById.mockResolvedValueOnce(null)
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
const response = await GET(req, { params: mockParams })
|
||||
const data = await response.json()
|
||||
|
||||
expect(response.status).toBe(404)
|
||||
expect(data.error).toBe('Knowledge base not found')
|
||||
})
|
||||
|
||||
it('should handle database errors', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
mockDbChain.limit.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
mockCheckKnowledgeBaseAccess.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('GET')
|
||||
const { GET } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -156,13 +212,13 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(undefined)
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ ...mockKnowledgeBase, ...validUpdateData }])
|
||||
const updatedKnowledgeBase = { ...mockKnowledgeBase, ...validUpdateData }
|
||||
mockUpdateKnowledgeBase.mockResolvedValueOnce(updatedKnowledgeBase)
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -172,7 +228,16 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.name).toBe('Updated Knowledge Base')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockUpdateKnowledgeBase).toHaveBeenCalledWith(
|
||||
'kb-123',
|
||||
{
|
||||
name: validUpdateData.name,
|
||||
description: validUpdateData.description,
|
||||
chunkingConfig: undefined,
|
||||
},
|
||||
expect.any(String)
|
||||
)
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -192,8 +257,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -209,8 +276,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
const invalidData = {
|
||||
name: '',
|
||||
@@ -229,9 +298,13 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should handle database errors during update', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
// Mock successful write access check
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
|
||||
mockUpdateKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('PUT', validUpdateData)
|
||||
const { PUT } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -251,10 +324,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockResolvedValueOnce(undefined)
|
||||
mockDeleteKnowledgeBase.mockResolvedValueOnce(undefined)
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -264,7 +339,8 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
expect(response.status).toBe(200)
|
||||
expect(data.success).toBe(true)
|
||||
expect(data.data.message).toBe('Knowledge base deleted successfully')
|
||||
expect(mockDbChain.update).toHaveBeenCalled()
|
||||
expect(mockCheckKnowledgeBaseWriteAccess).toHaveBeenCalledWith('kb-123', 'user-123')
|
||||
expect(mockDeleteKnowledgeBase).toHaveBeenCalledWith('kb-123', expect.any(String))
|
||||
})
|
||||
|
||||
it('should return unauthorized for unauthenticated user', async () => {
|
||||
@@ -284,8 +360,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: true,
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -301,8 +379,10 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
|
||||
resetMocks()
|
||||
|
||||
mockDbChain.where.mockReturnValueOnce(mockDbChain) // Return this to continue chain
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'different-user' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: false,
|
||||
notFound: false,
|
||||
})
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
@@ -316,9 +396,12 @@ describe('Knowledge Base By ID API Route', () => {
|
||||
it('should handle database errors during delete', async () => {
|
||||
mockAuth$.mockAuthenticatedUser()
|
||||
|
||||
mockDbChain.limit.mockResolvedValueOnce([{ id: 'kb-123', userId: 'user-123' }])
|
||||
mockCheckKnowledgeBaseWriteAccess.mockResolvedValueOnce({
|
||||
hasAccess: true,
|
||||
knowledgeBase: { id: 'kb-123', userId: 'user-123' },
|
||||
})
|
||||
|
||||
mockDbChain.where.mockRejectedValueOnce(new Error('Database error'))
|
||||
mockDeleteKnowledgeBase.mockRejectedValueOnce(new Error('Database error'))
|
||||
|
||||
const req = createMockRequest('DELETE')
|
||||
const { DELETE } = await import('@/app/api/knowledge/[id]/route')
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import {
|
||||
deleteKnowledgeBase,
|
||||
getKnowledgeBaseById,
|
||||
updateKnowledgeBase,
|
||||
} from '@/lib/knowledge/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess, checkKnowledgeBaseWriteAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBase } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseByIdAPI')
|
||||
|
||||
@@ -48,13 +50,9 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const knowledgeBases = await db
|
||||
.select()
|
||||
.from(knowledgeBase)
|
||||
.where(and(eq(knowledgeBase.id, id), isNull(knowledgeBase.deletedAt)))
|
||||
.limit(1)
|
||||
const knowledgeBaseData = await getKnowledgeBaseById(id)
|
||||
|
||||
if (knowledgeBases.length === 0) {
|
||||
if (!knowledgeBaseData) {
|
||||
return NextResponse.json({ error: 'Knowledge base not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
@@ -62,7 +60,7 @@ export async function GET(_req: NextRequest, { params }: { params: Promise<{ id:
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: knowledgeBases[0],
|
||||
data: knowledgeBaseData,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error fetching knowledge base`, error)
|
||||
@@ -99,42 +97,21 @@ export async function PUT(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
try {
|
||||
const validatedData = UpdateKnowledgeBaseSchema.parse(body)
|
||||
|
||||
const updateData: any = {
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
if (validatedData.name !== undefined) updateData.name = validatedData.name
|
||||
if (validatedData.description !== undefined)
|
||||
updateData.description = validatedData.description
|
||||
if (validatedData.workspaceId !== undefined)
|
||||
updateData.workspaceId = validatedData.workspaceId
|
||||
|
||||
// Handle embedding model and dimension together to ensure consistency
|
||||
if (
|
||||
validatedData.embeddingModel !== undefined ||
|
||||
validatedData.embeddingDimension !== undefined
|
||||
) {
|
||||
updateData.embeddingModel = 'text-embedding-3-small'
|
||||
updateData.embeddingDimension = 1536
|
||||
}
|
||||
|
||||
if (validatedData.chunkingConfig !== undefined)
|
||||
updateData.chunkingConfig = validatedData.chunkingConfig
|
||||
|
||||
await db.update(knowledgeBase).set(updateData).where(eq(knowledgeBase.id, id))
|
||||
|
||||
// Fetch the updated knowledge base
|
||||
const updatedKnowledgeBase = await db
|
||||
.select()
|
||||
.from(knowledgeBase)
|
||||
.where(eq(knowledgeBase.id, id))
|
||||
.limit(1)
|
||||
const updatedKnowledgeBase = await updateKnowledgeBase(
|
||||
id,
|
||||
{
|
||||
name: validatedData.name,
|
||||
description: validatedData.description,
|
||||
chunkingConfig: validatedData.chunkingConfig,
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base updated: ${id} for user ${session.user.id}`)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
data: updatedKnowledgeBase[0],
|
||||
data: updatedKnowledgeBase,
|
||||
})
|
||||
} catch (validationError) {
|
||||
if (validationError instanceof z.ZodError) {
|
||||
@@ -178,14 +155,7 @@ export async function DELETE(_req: NextRequest, { params }: { params: Promise<{
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Soft delete by setting deletedAt timestamp
|
||||
await db
|
||||
.update(knowledgeBase)
|
||||
.set({
|
||||
deletedAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(eq(knowledgeBase.id, id))
|
||||
await deleteKnowledgeBase(id, requestId)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base deleted: ${id} for user ${session.user.id}`)
|
||||
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, isNotNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { deleteTagDefinition } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -29,87 +27,16 @@ export async function DELETE(
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get the tag definition to find which slot it uses
|
||||
const tagDefinition = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.id, tagId),
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (tagDefinition.length === 0) {
|
||||
return NextResponse.json({ error: 'Tag definition not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const tagDef = tagDefinition[0]
|
||||
|
||||
// Delete the tag definition and clear all document tags in a transaction
|
||||
await db.transaction(async (tx) => {
|
||||
logger.info(`[${requestId}] Starting transaction to delete ${tagDef.tagSlot}`)
|
||||
|
||||
try {
|
||||
// Clear the tag from documents that actually have this tag set
|
||||
logger.info(`[${requestId}] Clearing tag from documents...`)
|
||||
await tx
|
||||
.update(document)
|
||||
.set({ [tagDef.tagSlot]: null })
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(document[tagDef.tagSlot as keyof typeof document.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Documents updated successfully`)
|
||||
|
||||
// Clear the tag from embeddings that actually have this tag set
|
||||
logger.info(`[${requestId}] Clearing tag from embeddings...`)
|
||||
await tx
|
||||
.update(embedding)
|
||||
.set({ [tagDef.tagSlot]: null })
|
||||
.where(
|
||||
and(
|
||||
eq(embedding.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(embedding[tagDef.tagSlot as keyof typeof embedding.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(`[${requestId}] Embeddings updated successfully`)
|
||||
|
||||
// Delete the tag definition
|
||||
logger.info(`[${requestId}] Deleting tag definition...`)
|
||||
await tx
|
||||
.delete(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.id, tagId))
|
||||
|
||||
logger.info(`[${requestId}] Tag definition deleted successfully`)
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error in transaction:`, error)
|
||||
throw error
|
||||
}
|
||||
})
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Successfully deleted tag definition ${tagDef.displayName} (${tagDef.tagSlot})`
|
||||
)
|
||||
const deletedTag = await deleteTagDefinition(tagId, requestId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
message: `Tag definition "${tagDef.displayName}" deleted successfully`,
|
||||
message: `Tag definition "${deletedTag.displayName}" deleted successfully`,
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error(`[${requestId}] Error deleting tag definition`, error)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { SUPPORTED_FIELD_TYPES } from '@/lib/knowledge/consts'
|
||||
import { createTagDefinition, getTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -24,25 +24,12 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get tag definitions for the knowledge base
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
fieldType: knowledgeBaseTagDefinitions.fieldType,
|
||||
createdAt: knowledgeBaseTagDefinitions.createdAt,
|
||||
updatedAt: knowledgeBaseTagDefinitions.updatedAt,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
.orderBy(knowledgeBaseTagDefinitions.tagSlot)
|
||||
const tagDefinitions = await getTagDefinitions(knowledgeBaseId)
|
||||
|
||||
logger.info(`[${requestId}] Retrieved ${tagDefinitions.length} tag definitions`)
|
||||
|
||||
@@ -69,68 +56,43 @@ export async function POST(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
const body = await req.json()
|
||||
const { tagSlot, displayName, fieldType } = body
|
||||
|
||||
if (!tagSlot || !displayName || !fieldType) {
|
||||
return NextResponse.json(
|
||||
{ error: 'tagSlot, displayName, and fieldType are required' },
|
||||
{ status: 400 }
|
||||
)
|
||||
}
|
||||
const CreateTagDefinitionSchema = z.object({
|
||||
tagSlot: z.string().min(1, 'Tag slot is required'),
|
||||
displayName: z.string().min(1, 'Display name is required'),
|
||||
fieldType: z.enum(SUPPORTED_FIELD_TYPES as [string, ...string[]], {
|
||||
errorMap: () => ({ message: 'Invalid field type' }),
|
||||
}),
|
||||
})
|
||||
|
||||
// Check if tag slot is already used
|
||||
const existingTag = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.tagSlot, tagSlot)
|
||||
let validatedData
|
||||
try {
|
||||
validatedData = CreateTagDefinitionSchema.parse(body)
|
||||
} catch (error) {
|
||||
if (error instanceof z.ZodError) {
|
||||
return NextResponse.json(
|
||||
{ error: 'Invalid request data', details: error.errors },
|
||||
{ status: 400 }
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (existingTag.length > 0) {
|
||||
return NextResponse.json({ error: 'Tag slot is already in use' }, { status: 409 })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
|
||||
// Check if display name is already used
|
||||
const existingName = await db
|
||||
.select()
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(
|
||||
and(
|
||||
eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId),
|
||||
eq(knowledgeBaseTagDefinitions.displayName, displayName)
|
||||
)
|
||||
)
|
||||
.limit(1)
|
||||
|
||||
if (existingName.length > 0) {
|
||||
return NextResponse.json({ error: 'Tag name is already in use' }, { status: 409 })
|
||||
}
|
||||
|
||||
// Create the new tag definition
|
||||
const newTagDefinition = {
|
||||
id: randomUUID(),
|
||||
knowledgeBaseId,
|
||||
tagSlot,
|
||||
displayName,
|
||||
fieldType,
|
||||
createdAt: new Date(),
|
||||
updatedAt: new Date(),
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBaseTagDefinitions).values(newTagDefinition)
|
||||
|
||||
logger.info(`[${requestId}] Successfully created tag definition ${displayName} (${tagSlot})`)
|
||||
const newTagDefinition = await createTagDefinition(
|
||||
{
|
||||
knowledgeBaseId,
|
||||
tagSlot: validatedData.tagSlot,
|
||||
displayName: validatedData.displayName,
|
||||
fieldType: validatedData.fieldType,
|
||||
},
|
||||
requestId
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import { randomUUID } from 'crypto'
|
||||
import { and, eq, isNotNull } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { getTagUsage } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
|
||||
export const dynamic = 'force-dynamic'
|
||||
|
||||
@@ -24,57 +22,15 @@ export async function GET(req: NextRequest, { params }: { params: Promise<{ id:
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check if user has access to the knowledge base
|
||||
const accessCheck = await checkKnowledgeBaseAccess(knowledgeBaseId, session.user.id)
|
||||
if (!accessCheck.hasAccess) {
|
||||
return NextResponse.json({ error: 'Forbidden' }, { status: 403 })
|
||||
}
|
||||
|
||||
// Get all tag definitions for the knowledge base
|
||||
const tagDefinitions = await db
|
||||
.select({
|
||||
id: knowledgeBaseTagDefinitions.id,
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, knowledgeBaseId))
|
||||
|
||||
// Get usage statistics for each tag definition
|
||||
const usageStats = await Promise.all(
|
||||
tagDefinitions.map(async (tagDef) => {
|
||||
// Count documents using this tag slot
|
||||
const tagSlotColumn = tagDef.tagSlot as keyof typeof document.$inferSelect
|
||||
|
||||
const documentsWithTag = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
filename: document.filename,
|
||||
[tagDef.tagSlot]: document[tagSlotColumn as keyof typeof document.$inferSelect] as any,
|
||||
})
|
||||
.from(document)
|
||||
.where(
|
||||
and(
|
||||
eq(document.knowledgeBaseId, knowledgeBaseId),
|
||||
isNotNull(document[tagSlotColumn as keyof typeof document.$inferSelect])
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
tagName: tagDef.displayName,
|
||||
tagSlot: tagDef.tagSlot,
|
||||
documentCount: documentsWithTag.length,
|
||||
documents: documentsWithTag.map((doc) => ({
|
||||
id: doc.id,
|
||||
name: doc.filename,
|
||||
tagValue: doc[tagDef.tagSlot],
|
||||
})),
|
||||
}
|
||||
})
|
||||
)
|
||||
const usageStats = await getTagUsage(knowledgeBaseId, requestId)
|
||||
|
||||
logger.info(
|
||||
`[${requestId}] Retrieved usage statistics for ${tagDefinitions.length} tag definitions`
|
||||
`[${requestId}] Retrieved usage statistics for ${usageStats.length} tag definitions`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
import { and, count, eq, isNotNull, isNull, or } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createKnowledgeBase, getKnowledgeBases } from '@/lib/knowledge/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, knowledgeBase, permissions } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeBaseAPI')
|
||||
|
||||
@@ -41,60 +38,10 @@ export async function GET(req: NextRequest) {
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
// Check for workspace filtering
|
||||
const { searchParams } = new URL(req.url)
|
||||
const workspaceId = searchParams.get('workspaceId')
|
||||
|
||||
// Get knowledge bases that user can access through direct ownership OR workspace permissions
|
||||
const knowledgeBasesWithCounts = await db
|
||||
.select({
|
||||
id: knowledgeBase.id,
|
||||
name: knowledgeBase.name,
|
||||
description: knowledgeBase.description,
|
||||
tokenCount: knowledgeBase.tokenCount,
|
||||
embeddingModel: knowledgeBase.embeddingModel,
|
||||
embeddingDimension: knowledgeBase.embeddingDimension,
|
||||
chunkingConfig: knowledgeBase.chunkingConfig,
|
||||
createdAt: knowledgeBase.createdAt,
|
||||
updatedAt: knowledgeBase.updatedAt,
|
||||
workspaceId: knowledgeBase.workspaceId,
|
||||
docCount: count(document.id),
|
||||
})
|
||||
.from(knowledgeBase)
|
||||
.leftJoin(
|
||||
document,
|
||||
and(eq(document.knowledgeBaseId, knowledgeBase.id), isNull(document.deletedAt))
|
||||
)
|
||||
.leftJoin(
|
||||
permissions,
|
||||
and(
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, knowledgeBase.workspaceId),
|
||||
eq(permissions.userId, session.user.id)
|
||||
)
|
||||
)
|
||||
.where(
|
||||
and(
|
||||
isNull(knowledgeBase.deletedAt),
|
||||
workspaceId
|
||||
? // When filtering by workspace
|
||||
or(
|
||||
// Knowledge bases belonging to the specified workspace (user must have workspace permissions)
|
||||
and(eq(knowledgeBase.workspaceId, workspaceId), isNotNull(permissions.userId)),
|
||||
// Fallback: User-owned knowledge bases without workspace (legacy)
|
||||
and(eq(knowledgeBase.userId, session.user.id), isNull(knowledgeBase.workspaceId))
|
||||
)
|
||||
: // When not filtering by workspace, use original logic
|
||||
or(
|
||||
// User owns the knowledge base directly
|
||||
eq(knowledgeBase.userId, session.user.id),
|
||||
// User has permissions on the knowledge base's workspace
|
||||
isNotNull(permissions.userId)
|
||||
)
|
||||
)
|
||||
)
|
||||
.groupBy(knowledgeBase.id)
|
||||
.orderBy(knowledgeBase.createdAt)
|
||||
const knowledgeBasesWithCounts = await getKnowledgeBases(session.user.id, workspaceId)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -121,49 +68,16 @@ export async function POST(req: NextRequest) {
|
||||
try {
|
||||
const validatedData = CreateKnowledgeBaseSchema.parse(body)
|
||||
|
||||
// If creating in a workspace, check if user has write/admin permissions
|
||||
if (validatedData.workspaceId) {
|
||||
const userPermission = await getUserEntityPermissions(
|
||||
session.user.id,
|
||||
'workspace',
|
||||
validatedData.workspaceId
|
||||
)
|
||||
if (userPermission !== 'write' && userPermission !== 'admin') {
|
||||
logger.warn(
|
||||
`[${requestId}] User ${session.user.id} denied permission to create knowledge base in workspace ${validatedData.workspaceId}`
|
||||
)
|
||||
return NextResponse.json(
|
||||
{ error: 'Insufficient permissions to create knowledge base in this workspace' },
|
||||
{ status: 403 }
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
const id = crypto.randomUUID()
|
||||
const now = new Date()
|
||||
|
||||
const newKnowledgeBase = {
|
||||
id,
|
||||
const createData = {
|
||||
...validatedData,
|
||||
userId: session.user.id,
|
||||
workspaceId: validatedData.workspaceId || null,
|
||||
name: validatedData.name,
|
||||
description: validatedData.description || null,
|
||||
tokenCount: 0,
|
||||
embeddingModel: validatedData.embeddingModel,
|
||||
embeddingDimension: validatedData.embeddingDimension,
|
||||
chunkingConfig: validatedData.chunkingConfig || {
|
||||
maxSize: 1024,
|
||||
minSize: 100,
|
||||
overlap: 200,
|
||||
},
|
||||
docCount: 0,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
|
||||
await db.insert(knowledgeBase).values(newKnowledgeBase)
|
||||
const newKnowledgeBase = await createKnowledgeBase(createData, requestId)
|
||||
|
||||
logger.info(`[${requestId}] Knowledge base created: ${id} for user ${session.user.id}`)
|
||||
logger.info(
|
||||
`[${requestId}] Knowledge base created: ${newKnowledgeBase.id} for user ${session.user.id}`
|
||||
)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
|
||||
@@ -65,12 +65,14 @@ const mockHandleVectorOnlySearch = vi.fn()
|
||||
const mockHandleTagAndVectorSearch = vi.fn()
|
||||
const mockGetQueryStrategy = vi.fn()
|
||||
const mockGenerateSearchEmbedding = vi.fn()
|
||||
const mockGetDocumentNamesByIds = vi.fn()
|
||||
vi.mock('./utils', () => ({
|
||||
handleTagOnlySearch: mockHandleTagOnlySearch,
|
||||
handleVectorOnlySearch: mockHandleVectorOnlySearch,
|
||||
handleTagAndVectorSearch: mockHandleTagAndVectorSearch,
|
||||
getQueryStrategy: mockGetQueryStrategy,
|
||||
generateSearchEmbedding: mockGenerateSearchEmbedding,
|
||||
getDocumentNamesByIds: mockGetDocumentNamesByIds,
|
||||
APIError: class APIError extends Error {
|
||||
public status: number
|
||||
constructor(message: string, status: number) {
|
||||
@@ -146,6 +148,10 @@ describe('Knowledge Search API Route', () => {
|
||||
singleQueryOptimized: true,
|
||||
})
|
||||
mockGenerateSearchEmbedding.mockClear().mockResolvedValue([0.1, 0.2, 0.3, 0.4, 0.5])
|
||||
mockGetDocumentNamesByIds.mockClear().mockResolvedValue({
|
||||
doc1: 'Document 1',
|
||||
doc2: 'Document 2',
|
||||
})
|
||||
|
||||
vi.stubGlobal('crypto', {
|
||||
randomUUID: vi.fn().mockReturnValue('mock-uuid-1234-5678'),
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import { eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { z } from 'zod'
|
||||
import { TAG_SLOTS } from '@/lib/constants/knowledge'
|
||||
import { TAG_SLOTS } from '@/lib/knowledge/consts'
|
||||
import { getDocumentTagDefinitions } from '@/lib/knowledge/tags/service'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { estimateTokenCount } from '@/lib/tokenization/estimators'
|
||||
import { getUserId } from '@/app/api/auth/oauth/utils'
|
||||
import { checkKnowledgeBaseAccess } from '@/app/api/knowledge/utils'
|
||||
import { db } from '@/db'
|
||||
import { knowledgeBaseTagDefinitions } from '@/db/schema'
|
||||
import { calculateCost } from '@/providers/utils'
|
||||
import {
|
||||
generateSearchEmbedding,
|
||||
getDocumentNamesByIds,
|
||||
getQueryStrategy,
|
||||
handleTagAndVectorSearch,
|
||||
handleTagOnlySearch,
|
||||
@@ -79,14 +78,13 @@ export async function POST(request: NextRequest) {
|
||||
? validatedData.knowledgeBaseIds
|
||||
: [validatedData.knowledgeBaseIds]
|
||||
|
||||
// Check access permissions for each knowledge base using proper workspace-based permissions
|
||||
const accessibleKbIds: string[] = []
|
||||
for (const kbId of knowledgeBaseIds) {
|
||||
const accessCheck = await checkKnowledgeBaseAccess(kbId, userId)
|
||||
if (accessCheck.hasAccess) {
|
||||
accessibleKbIds.push(kbId)
|
||||
}
|
||||
}
|
||||
// Check access permissions in parallel for performance
|
||||
const accessChecks = await Promise.all(
|
||||
knowledgeBaseIds.map((kbId) => checkKnowledgeBaseAccess(kbId, userId))
|
||||
)
|
||||
const accessibleKbIds: string[] = knowledgeBaseIds.filter(
|
||||
(_, idx) => accessChecks[idx]?.hasAccess
|
||||
)
|
||||
|
||||
// Map display names to tag slots for filtering
|
||||
let mappedFilters: Record<string, string> = {}
|
||||
@@ -94,13 +92,7 @@ export async function POST(request: NextRequest) {
|
||||
try {
|
||||
// Fetch tag definitions for the first accessible KB (since we're using single KB now)
|
||||
const kbId = accessibleKbIds[0]
|
||||
const tagDefs = await db
|
||||
.select({
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
|
||||
const tagDefs = await getDocumentTagDefinitions(kbId)
|
||||
|
||||
logger.debug(`[${requestId}] Found tag definitions:`, tagDefs)
|
||||
logger.debug(`[${requestId}] Original filters:`, validatedData.filters)
|
||||
@@ -145,7 +137,10 @@ export async function POST(request: NextRequest) {
|
||||
|
||||
// Generate query embedding only if query is provided
|
||||
const hasQuery = validatedData.query && validatedData.query.trim().length > 0
|
||||
const queryEmbedding = hasQuery ? await generateSearchEmbedding(validatedData.query!) : null
|
||||
// Start embedding generation early and await when needed
|
||||
const queryEmbeddingPromise = hasQuery
|
||||
? generateSearchEmbedding(validatedData.query!)
|
||||
: Promise.resolve(null)
|
||||
|
||||
// Check if any requested knowledge bases were not accessible
|
||||
const inaccessibleKbIds = knowledgeBaseIds.filter((id) => !accessibleKbIds.includes(id))
|
||||
@@ -173,7 +168,7 @@ export async function POST(request: NextRequest) {
|
||||
// Tag + Vector search
|
||||
logger.debug(`[${requestId}] Executing tag + vector search with filters:`, mappedFilters)
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
const queryVector = JSON.stringify(await queryEmbeddingPromise)
|
||||
|
||||
results = await handleTagAndVectorSearch({
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
@@ -186,7 +181,7 @@ export async function POST(request: NextRequest) {
|
||||
// Vector-only search
|
||||
logger.debug(`[${requestId}] Executing vector-only search`)
|
||||
const strategy = getQueryStrategy(accessibleKbIds.length, validatedData.topK)
|
||||
const queryVector = JSON.stringify(queryEmbedding)
|
||||
const queryVector = JSON.stringify(await queryEmbeddingPromise)
|
||||
|
||||
results = await handleVectorOnlySearch({
|
||||
knowledgeBaseIds: accessibleKbIds,
|
||||
@@ -221,30 +216,32 @@ export async function POST(request: NextRequest) {
|
||||
}
|
||||
|
||||
// Fetch tag definitions for display name mapping (reuse the same fetch from filtering)
|
||||
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
|
||||
for (const kbId of accessibleKbIds) {
|
||||
try {
|
||||
const tagDefs = await db
|
||||
.select({
|
||||
tagSlot: knowledgeBaseTagDefinitions.tagSlot,
|
||||
displayName: knowledgeBaseTagDefinitions.displayName,
|
||||
const tagDefsResults = await Promise.all(
|
||||
accessibleKbIds.map(async (kbId) => {
|
||||
try {
|
||||
const tagDefs = await getDocumentTagDefinitions(kbId)
|
||||
const map: Record<string, string> = {}
|
||||
tagDefs.forEach((def) => {
|
||||
map[def.tagSlot] = def.displayName
|
||||
})
|
||||
.from(knowledgeBaseTagDefinitions)
|
||||
.where(eq(knowledgeBaseTagDefinitions.knowledgeBaseId, kbId))
|
||||
return { kbId, map }
|
||||
} catch (error) {
|
||||
logger.warn(
|
||||
`[${requestId}] Failed to fetch tag definitions for display mapping:`,
|
||||
error
|
||||
)
|
||||
return { kbId, map: {} as Record<string, string> }
|
||||
}
|
||||
})
|
||||
)
|
||||
const tagDefinitionsMap: Record<string, Record<string, string>> = {}
|
||||
tagDefsResults.forEach(({ kbId, map }) => {
|
||||
tagDefinitionsMap[kbId] = map
|
||||
})
|
||||
|
||||
tagDefinitionsMap[kbId] = {}
|
||||
tagDefs.forEach((def) => {
|
||||
tagDefinitionsMap[kbId][def.tagSlot] = def.displayName
|
||||
})
|
||||
logger.debug(
|
||||
`[${requestId}] Display mapping - KB ${kbId} tag definitions:`,
|
||||
tagDefinitionsMap[kbId]
|
||||
)
|
||||
} catch (error) {
|
||||
logger.warn(`[${requestId}] Failed to fetch tag definitions for display mapping:`, error)
|
||||
tagDefinitionsMap[kbId] = {}
|
||||
}
|
||||
}
|
||||
// Fetch document names for the results
|
||||
const documentIds = results.map((result) => result.documentId)
|
||||
const documentNameMap = await getDocumentNamesByIds(documentIds)
|
||||
|
||||
return NextResponse.json({
|
||||
success: true,
|
||||
@@ -271,11 +268,11 @@ export async function POST(request: NextRequest) {
|
||||
})
|
||||
|
||||
return {
|
||||
id: result.id,
|
||||
content: result.content,
|
||||
documentId: result.documentId,
|
||||
documentName: documentNameMap[result.documentId] || undefined,
|
||||
content: result.content,
|
||||
chunkIndex: result.chunkIndex,
|
||||
tags, // Clean display name mapped tags
|
||||
metadata: tags, // Clean display name mapped tags
|
||||
similarity: hasQuery ? 1 - result.distance : 1, // Perfect similarity for tag-only searches
|
||||
}
|
||||
}),
|
||||
|
||||
@@ -4,15 +4,50 @@
|
||||
*
|
||||
* @vitest-environment node
|
||||
*/
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('drizzle-orm')
|
||||
vi.mock('@/lib/logs/console/logger')
|
||||
vi.mock('@/lib/logs/console/logger', () => ({
|
||||
createLogger: vi.fn(() => ({
|
||||
info: vi.fn(),
|
||||
debug: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
})),
|
||||
}))
|
||||
vi.mock('@/db')
|
||||
vi.mock('@/lib/knowledge/documents/utils', () => ({
|
||||
retryWithExponentialBackoff: (fn: any) => fn(),
|
||||
}))
|
||||
|
||||
import { handleTagAndVectorSearch, handleTagOnlySearch, handleVectorOnlySearch } from './utils'
|
||||
vi.stubGlobal(
|
||||
'fetch',
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
vi.mock('@/lib/env', () => ({
|
||||
env: {},
|
||||
isTruthy: (value: string | boolean | number | undefined) =>
|
||||
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
|
||||
}))
|
||||
|
||||
import {
|
||||
generateSearchEmbedding,
|
||||
handleTagAndVectorSearch,
|
||||
handleTagOnlySearch,
|
||||
handleVectorOnlySearch,
|
||||
} from './utils'
|
||||
|
||||
describe('Knowledge Search Utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('handleTagOnlySearch', () => {
|
||||
it('should throw error when no filters provided', async () => {
|
||||
const params = {
|
||||
@@ -140,4 +175,251 @@ describe('Knowledge Search Utils', () => {
|
||||
expect(params.distanceThreshold).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe('generateSearchEmbedding', () => {
|
||||
it('should use Azure OpenAI when KB-specific config is provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
'api-key': 'test-azure-key',
|
||||
}),
|
||||
})
|
||||
)
|
||||
expect(result).toEqual([0.1, 0.2, 0.3])
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should fallback to OpenAI when no KB Azure config provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
const result = await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
'https://api.openai.com/v1/embeddings',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-openai-key',
|
||||
}),
|
||||
})
|
||||
)
|
||||
expect(result).toEqual([0.1, 0.2, 0.3])
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should use default API version when not provided in Azure config', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect.stringContaining('api-version='),
|
||||
expect.any(Object)
|
||||
)
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should use custom model name when provided in Azure config', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
|
||||
KB_OPENAI_MODEL_NAME: 'custom-embedding-model',
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateSearchEmbedding('test query', 'text-embedding-3-small')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
'https://test.openai.azure.com/openai/deployments/custom-embedding-model/embeddings?api-version=2024-12-01-preview',
|
||||
expect.any(Object)
|
||||
)
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should throw error when no API configuration provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
|
||||
await expect(generateSearchEmbedding('test query')).rejects.toThrow(
|
||||
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle Azure OpenAI API errors properly', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 404,
|
||||
statusText: 'Not Found',
|
||||
text: async () => 'Deployment not found',
|
||||
} as any)
|
||||
|
||||
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should handle OpenAI API errors properly', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 429,
|
||||
statusText: 'Too Many Requests',
|
||||
text: async () => 'Rate limit exceeded',
|
||||
} as any)
|
||||
|
||||
await expect(generateSearchEmbedding('test query')).rejects.toThrow('Embedding API failed')
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should include correct request body for Azure OpenAI', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateSearchEmbedding('test query')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
input: ['test query'],
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should include correct request body for OpenAI', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2, 0.3] }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateSearchEmbedding('test query', 'text-embedding-3-small')
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.objectContaining({
|
||||
body: JSON.stringify({
|
||||
input: ['test query'],
|
||||
model: 'text-embedding-3-small',
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
// Clean up
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,20 +1,32 @@
|
||||
import { and, eq, inArray, sql } from 'drizzle-orm'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { embedding } from '@/db/schema'
|
||||
import { document, embedding } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeSearchUtils')
|
||||
|
||||
export class APIError extends Error {
|
||||
public status: number
|
||||
|
||||
constructor(message: string, status: number) {
|
||||
super(message)
|
||||
this.name = 'APIError'
|
||||
this.status = status
|
||||
export async function getDocumentNamesByIds(
|
||||
documentIds: string[]
|
||||
): Promise<Record<string, string>> {
|
||||
if (documentIds.length === 0) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const uniqueIds = [...new Set(documentIds)]
|
||||
const documents = await db
|
||||
.select({
|
||||
id: document.id,
|
||||
filename: document.filename,
|
||||
})
|
||||
.from(document)
|
||||
.where(inArray(document.id, uniqueIds))
|
||||
|
||||
const documentNameMap: Record<string, string> = {}
|
||||
documents.forEach((doc) => {
|
||||
documentNameMap[doc.id] = doc.filename
|
||||
})
|
||||
|
||||
return documentNameMap
|
||||
}
|
||||
|
||||
export interface SearchResult {
|
||||
@@ -41,61 +53,8 @@ export interface SearchParams {
|
||||
distanceThreshold?: number
|
||||
}
|
||||
|
||||
export async function generateSearchEmbedding(query: string): Promise<number[]> {
|
||||
const openaiApiKey = env.OPENAI_API_KEY
|
||||
if (!openaiApiKey) {
|
||||
throw new Error('OPENAI_API_KEY not configured')
|
||||
}
|
||||
|
||||
try {
|
||||
const embedding = await retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const response = await fetch('https://api.openai.com/v1/embeddings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${openaiApiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: query,
|
||||
model: 'text-embedding-3-small',
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
const error = new APIError(
|
||||
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
response.status
|
||||
)
|
||||
throw error
|
||||
}
|
||||
|
||||
const data = await response.json()
|
||||
|
||||
if (!data.data || !Array.isArray(data.data) || data.data.length === 0) {
|
||||
throw new Error('Invalid response format from OpenAI embeddings API')
|
||||
}
|
||||
|
||||
return data.data[0].embedding
|
||||
},
|
||||
{
|
||||
maxRetries: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 30000,
|
||||
backoffMultiplier: 2,
|
||||
}
|
||||
)
|
||||
|
||||
return embedding
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate search embedding:', error)
|
||||
throw new Error(
|
||||
`Embedding generation failed: ${error instanceof Error ? error.message : 'Unknown error'}`
|
||||
)
|
||||
}
|
||||
}
|
||||
// Use shared embedding utility
|
||||
export { generateSearchEmbedding } from '@/lib/embeddings/utils'
|
||||
|
||||
function getTagFilters(filters: Record<string, string>, embedding: any) {
|
||||
return Object.entries(filters).map(([key, value]) => {
|
||||
|
||||
@@ -21,11 +21,11 @@ vi.mock('@/lib/env', () => ({
|
||||
typeof value === 'string' ? value === 'true' || value === '1' : Boolean(value),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/documents/utils', () => ({
|
||||
vi.mock('@/lib/knowledge/documents/utils', () => ({
|
||||
retryWithExponentialBackoff: (fn: any) => fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/lib/documents/document-processor', () => ({
|
||||
vi.mock('@/lib/knowledge/documents/document-processor', () => ({
|
||||
processDocument: vi.fn().mockResolvedValue({
|
||||
chunks: [
|
||||
{
|
||||
@@ -149,12 +149,12 @@ vi.mock('@/db', () => {
|
||||
}
|
||||
})
|
||||
|
||||
import { generateEmbeddings } from '@/lib/embeddings/utils'
|
||||
import { processDocumentAsync } from '@/lib/knowledge/documents/service'
|
||||
import {
|
||||
checkChunkAccess,
|
||||
checkDocumentAccess,
|
||||
checkKnowledgeBaseAccess,
|
||||
generateEmbeddings,
|
||||
processDocumentAsync,
|
||||
} from '@/app/api/knowledge/utils'
|
||||
|
||||
describe('Knowledge Utils', () => {
|
||||
@@ -252,5 +252,76 @@ describe('Knowledge Utils', () => {
|
||||
|
||||
expect(result.length).toBe(2)
|
||||
})
|
||||
|
||||
it('should use Azure OpenAI when Azure config is provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
AZURE_OPENAI_API_KEY: 'test-azure-key',
|
||||
AZURE_OPENAI_ENDPOINT: 'https://test.openai.azure.com',
|
||||
AZURE_OPENAI_API_VERSION: '2024-12-01-preview',
|
||||
KB_OPENAI_MODEL_NAME: 'text-embedding-ada-002',
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2], index: 0 }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateEmbeddings(['test text'])
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
'https://test.openai.azure.com/openai/deployments/text-embedding-ada-002/embeddings?api-version=2024-12-01-preview',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
'api-key': 'test-azure-key',
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should fallback to OpenAI when no Azure config provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
Object.assign(env, {
|
||||
OPENAI_API_KEY: 'test-openai-key',
|
||||
})
|
||||
|
||||
const fetchSpy = vi.mocked(fetch)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
data: [{ embedding: [0.1, 0.2], index: 0 }],
|
||||
}),
|
||||
} as any)
|
||||
|
||||
await generateEmbeddings(['test text'])
|
||||
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
'https://api.openai.com/v1/embeddings',
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: 'Bearer test-openai-key',
|
||||
}),
|
||||
})
|
||||
)
|
||||
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
})
|
||||
|
||||
it('should throw error when no API configuration provided', async () => {
|
||||
const { env } = await import('@/lib/env')
|
||||
Object.keys(env).forEach((key) => delete (env as any)[key])
|
||||
|
||||
await expect(generateEmbeddings(['test text'])).rejects.toThrow(
|
||||
'Either OPENAI_API_KEY or Azure OpenAI configuration (AZURE_OPENAI_API_KEY + AZURE_OPENAI_ENDPOINT) must be configured'
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,47 +1,8 @@
|
||||
import crypto from 'crypto'
|
||||
import { and, eq, isNull } from 'drizzle-orm'
|
||||
import { processDocument } from '@/lib/documents/document-processor'
|
||||
import { retryWithExponentialBackoff } from '@/lib/documents/utils'
|
||||
import { env } from '@/lib/env'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { getUserEntityPermissions } from '@/lib/permissions/utils'
|
||||
import { db } from '@/db'
|
||||
import { document, embedding, knowledgeBase } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('KnowledgeUtils')
|
||||
|
||||
// Timeout constants (in milliseconds)
|
||||
const TIMEOUTS = {
|
||||
OVERALL_PROCESSING: 150000, // 150 seconds (2.5 minutes)
|
||||
EMBEDDINGS_API: 60000, // 60 seconds per batch
|
||||
} as const
|
||||
|
||||
class APIError extends Error {
|
||||
public status: number
|
||||
|
||||
constructor(message: string, status: number) {
|
||||
super(message)
|
||||
this.name = 'APIError'
|
||||
this.status = status
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a timeout wrapper for async operations
|
||||
*/
|
||||
function withTimeout<T>(
|
||||
promise: Promise<T>,
|
||||
timeoutMs: number,
|
||||
operation = 'Operation'
|
||||
): Promise<T> {
|
||||
return Promise.race([
|
||||
promise,
|
||||
new Promise<never>((_, reject) =>
|
||||
setTimeout(() => reject(new Error(`${operation} timed out after ${timeoutMs}ms`)), timeoutMs)
|
||||
),
|
||||
])
|
||||
}
|
||||
|
||||
export interface KnowledgeBaseData {
|
||||
id: string
|
||||
userId: string
|
||||
@@ -110,18 +71,6 @@ export interface EmbeddingData {
|
||||
updatedAt: Date
|
||||
}
|
||||
|
||||
interface OpenAIEmbeddingResponse {
|
||||
data: Array<{
|
||||
embedding: number[]
|
||||
index: number
|
||||
}>
|
||||
model: string
|
||||
usage: {
|
||||
prompt_tokens: number
|
||||
total_tokens: number
|
||||
}
|
||||
}
|
||||
|
||||
export interface KnowledgeBaseAccessResult {
|
||||
hasAccess: true
|
||||
knowledgeBase: Pick<KnowledgeBaseData, 'id' | 'userId'>
|
||||
@@ -404,233 +353,3 @@ export async function checkChunkAccess(
|
||||
knowledgeBase: kbAccess.knowledgeBase!,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate embeddings using OpenAI API with retry logic for rate limiting
|
||||
*/
|
||||
export async function generateEmbeddings(
|
||||
texts: string[],
|
||||
embeddingModel = 'text-embedding-3-small'
|
||||
): Promise<number[][]> {
|
||||
const openaiApiKey = env.OPENAI_API_KEY
|
||||
if (!openaiApiKey) {
|
||||
throw new Error('OPENAI_API_KEY not configured')
|
||||
}
|
||||
|
||||
try {
|
||||
const batchSize = 100
|
||||
const allEmbeddings: number[][] = []
|
||||
|
||||
for (let i = 0; i < texts.length; i += batchSize) {
|
||||
const batch = texts.slice(i, i + batchSize)
|
||||
|
||||
logger.info(
|
||||
`Generating embeddings for batch ${Math.floor(i / batchSize) + 1} (${batch.length} texts)`
|
||||
)
|
||||
|
||||
const batchEmbeddings = await retryWithExponentialBackoff(
|
||||
async () => {
|
||||
const controller = new AbortController()
|
||||
const timeoutId = setTimeout(() => controller.abort(), TIMEOUTS.EMBEDDINGS_API)
|
||||
|
||||
try {
|
||||
const response = await fetch('https://api.openai.com/v1/embeddings', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
Authorization: `Bearer ${openaiApiKey}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
input: batch,
|
||||
model: embeddingModel,
|
||||
encoding_format: 'float',
|
||||
}),
|
||||
signal: controller.signal,
|
||||
})
|
||||
|
||||
clearTimeout(timeoutId)
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text()
|
||||
const error = new APIError(
|
||||
`OpenAI API error: ${response.status} ${response.statusText} - ${errorText}`,
|
||||
response.status
|
||||
)
|
||||
throw error
|
||||
}
|
||||
|
||||
const data: OpenAIEmbeddingResponse = await response.json()
|
||||
return data.data.map((item) => item.embedding)
|
||||
} catch (error) {
|
||||
clearTimeout(timeoutId)
|
||||
if (error instanceof Error && error.name === 'AbortError') {
|
||||
throw new Error('OpenAI API request timed out')
|
||||
}
|
||||
throw error
|
||||
}
|
||||
},
|
||||
{
|
||||
maxRetries: 5,
|
||||
initialDelayMs: 1000,
|
||||
maxDelayMs: 60000, // Max 1 minute delay for embeddings
|
||||
backoffMultiplier: 2,
|
||||
}
|
||||
)
|
||||
|
||||
allEmbeddings.push(...batchEmbeddings)
|
||||
}
|
||||
|
||||
return allEmbeddings
|
||||
} catch (error) {
|
||||
logger.error('Failed to generate embeddings:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a document asynchronously with full error handling
|
||||
*/
|
||||
export async function processDocumentAsync(
|
||||
knowledgeBaseId: string,
|
||||
documentId: string,
|
||||
docData: {
|
||||
filename: string
|
||||
fileUrl: string
|
||||
fileSize: number
|
||||
mimeType: string
|
||||
},
|
||||
processingOptions: {
|
||||
chunkSize?: number
|
||||
minCharactersPerChunk?: number
|
||||
recipe?: string
|
||||
lang?: string
|
||||
chunkOverlap?: number
|
||||
}
|
||||
): Promise<void> {
|
||||
const startTime = Date.now()
|
||||
try {
|
||||
logger.info(`[${documentId}] Starting document processing: ${docData.filename}`)
|
||||
|
||||
// Set status to processing
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'processing',
|
||||
processingStartedAt: new Date(),
|
||||
processingError: null, // Clear any previous error
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
|
||||
logger.info(`[${documentId}] Status updated to 'processing', starting document processor`)
|
||||
|
||||
// Wrap the entire processing operation with a 5-minute timeout
|
||||
await withTimeout(
|
||||
(async () => {
|
||||
const processed = await processDocument(
|
||||
docData.fileUrl,
|
||||
docData.filename,
|
||||
docData.mimeType,
|
||||
processingOptions.chunkSize || 1000,
|
||||
processingOptions.chunkOverlap || 200,
|
||||
processingOptions.minCharactersPerChunk || 1
|
||||
)
|
||||
|
||||
const now = new Date()
|
||||
|
||||
logger.info(
|
||||
`[${documentId}] Document parsed successfully, generating embeddings for ${processed.chunks.length} chunks`
|
||||
)
|
||||
|
||||
const chunkTexts = processed.chunks.map((chunk) => chunk.text)
|
||||
const embeddings = chunkTexts.length > 0 ? await generateEmbeddings(chunkTexts) : []
|
||||
|
||||
logger.info(`[${documentId}] Embeddings generated, fetching document tags`)
|
||||
|
||||
// Fetch document to get tags
|
||||
const documentRecord = await db
|
||||
.select({
|
||||
tag1: document.tag1,
|
||||
tag2: document.tag2,
|
||||
tag3: document.tag3,
|
||||
tag4: document.tag4,
|
||||
tag5: document.tag5,
|
||||
tag6: document.tag6,
|
||||
tag7: document.tag7,
|
||||
})
|
||||
.from(document)
|
||||
.where(eq(document.id, documentId))
|
||||
.limit(1)
|
||||
|
||||
const documentTags = documentRecord[0] || {}
|
||||
|
||||
logger.info(`[${documentId}] Creating embedding records with tags`)
|
||||
|
||||
const embeddingRecords = processed.chunks.map((chunk, chunkIndex) => ({
|
||||
id: crypto.randomUUID(),
|
||||
knowledgeBaseId,
|
||||
documentId,
|
||||
chunkIndex,
|
||||
chunkHash: crypto.createHash('sha256').update(chunk.text).digest('hex'),
|
||||
content: chunk.text,
|
||||
contentLength: chunk.text.length,
|
||||
tokenCount: Math.ceil(chunk.text.length / 4),
|
||||
embedding: embeddings[chunkIndex] || null,
|
||||
embeddingModel: 'text-embedding-3-small',
|
||||
startOffset: chunk.metadata.startIndex,
|
||||
endOffset: chunk.metadata.endIndex,
|
||||
// Copy tags from document
|
||||
tag1: documentTags.tag1,
|
||||
tag2: documentTags.tag2,
|
||||
tag3: documentTags.tag3,
|
||||
tag4: documentTags.tag4,
|
||||
tag5: documentTags.tag5,
|
||||
tag6: documentTags.tag6,
|
||||
tag7: documentTags.tag7,
|
||||
createdAt: now,
|
||||
updatedAt: now,
|
||||
}))
|
||||
|
||||
await db.transaction(async (tx) => {
|
||||
if (embeddingRecords.length > 0) {
|
||||
await tx.insert(embedding).values(embeddingRecords)
|
||||
}
|
||||
|
||||
await tx
|
||||
.update(document)
|
||||
.set({
|
||||
chunkCount: processed.metadata.chunkCount,
|
||||
tokenCount: processed.metadata.tokenCount,
|
||||
characterCount: processed.metadata.characterCount,
|
||||
processingStatus: 'completed',
|
||||
processingCompletedAt: now,
|
||||
processingError: null,
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
})
|
||||
})(),
|
||||
TIMEOUTS.OVERALL_PROCESSING,
|
||||
'Document processing'
|
||||
)
|
||||
|
||||
const processingTime = Date.now() - startTime
|
||||
logger.info(`[${documentId}] Successfully processed document in ${processingTime}ms`)
|
||||
} catch (error) {
|
||||
const processingTime = Date.now() - startTime
|
||||
logger.error(`[${documentId}] Failed to process document after ${processingTime}ms:`, {
|
||||
error: error instanceof Error ? error.message : 'Unknown error',
|
||||
stack: error instanceof Error ? error.stack : undefined,
|
||||
filename: docData.filename,
|
||||
fileUrl: docData.fileUrl,
|
||||
mimeType: docData.mimeType,
|
||||
})
|
||||
|
||||
await db
|
||||
.update(document)
|
||||
.set({
|
||||
processingStatus: 'failed',
|
||||
processingError: error instanceof Error ? error.message : 'Unknown error',
|
||||
processingCompletedAt: new Date(),
|
||||
})
|
||||
.where(eq(document.id, documentId))
|
||||
}
|
||||
}
|
||||
|
||||
102
apps/sim/app/api/logs/[id]/route.ts
Normal file
102
apps/sim/app/api/logs/[id]/route.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import { and, eq } from 'drizzle-orm'
|
||||
import { type NextRequest, NextResponse } from 'next/server'
|
||||
import { getSession } from '@/lib/auth'
|
||||
import { createLogger } from '@/lib/logs/console/logger'
|
||||
import { db } from '@/db'
|
||||
import { permissions, workflow, workflowExecutionLogs } from '@/db/schema'
|
||||
|
||||
const logger = createLogger('LogDetailsByIdAPI')
|
||||
|
||||
export const revalidate = 0
|
||||
|
||||
export async function GET(_request: NextRequest, { params }: { params: Promise<{ id: string }> }) {
|
||||
const requestId = crypto.randomUUID().slice(0, 8)
|
||||
|
||||
try {
|
||||
const session = await getSession()
|
||||
if (!session?.user?.id) {
|
||||
logger.warn(`[${requestId}] Unauthorized log details access attempt`)
|
||||
return NextResponse.json({ error: 'Unauthorized' }, { status: 401 })
|
||||
}
|
||||
|
||||
const userId = session.user.id
|
||||
const { id } = await params
|
||||
|
||||
const rows = await db
|
||||
.select({
|
||||
id: workflowExecutionLogs.id,
|
||||
workflowId: workflowExecutionLogs.workflowId,
|
||||
executionId: workflowExecutionLogs.executionId,
|
||||
stateSnapshotId: workflowExecutionLogs.stateSnapshotId,
|
||||
level: workflowExecutionLogs.level,
|
||||
trigger: workflowExecutionLogs.trigger,
|
||||
startedAt: workflowExecutionLogs.startedAt,
|
||||
endedAt: workflowExecutionLogs.endedAt,
|
||||
totalDurationMs: workflowExecutionLogs.totalDurationMs,
|
||||
executionData: workflowExecutionLogs.executionData,
|
||||
cost: workflowExecutionLogs.cost,
|
||||
files: workflowExecutionLogs.files,
|
||||
createdAt: workflowExecutionLogs.createdAt,
|
||||
workflowName: workflow.name,
|
||||
workflowDescription: workflow.description,
|
||||
workflowColor: workflow.color,
|
||||
workflowFolderId: workflow.folderId,
|
||||
workflowUserId: workflow.userId,
|
||||
workflowWorkspaceId: workflow.workspaceId,
|
||||
workflowCreatedAt: workflow.createdAt,
|
||||
workflowUpdatedAt: workflow.updatedAt,
|
||||
})
|
||||
.from(workflowExecutionLogs)
|
||||
.innerJoin(workflow, eq(workflowExecutionLogs.workflowId, workflow.id))
|
||||
.innerJoin(
|
||||
permissions,
|
||||
and(
|
||||
eq(permissions.entityType, 'workspace'),
|
||||
eq(permissions.entityId, workflow.workspaceId),
|
||||
eq(permissions.userId, userId)
|
||||
)
|
||||
)
|
||||
.where(eq(workflowExecutionLogs.id, id))
|
||||
.limit(1)
|
||||
|
||||
const log = rows[0]
|
||||
if (!log) {
|
||||
return NextResponse.json({ error: 'Not found' }, { status: 404 })
|
||||
}
|
||||
|
||||
const workflowSummary = {
|
||||
id: log.workflowId,
|
||||
name: log.workflowName,
|
||||
description: log.workflowDescription,
|
||||
color: log.workflowColor,
|
||||
folderId: log.workflowFolderId,
|
||||
userId: log.workflowUserId,
|
||||
workspaceId: log.workflowWorkspaceId,
|
||||
createdAt: log.workflowCreatedAt,
|
||||
updatedAt: log.workflowUpdatedAt,
|
||||
}
|
||||
|
||||
const response = {
|
||||
id: log.id,
|
||||
workflowId: log.workflowId,
|
||||
executionId: log.executionId,
|
||||
level: log.level,
|
||||
duration: log.totalDurationMs ? `${log.totalDurationMs}ms` : null,
|
||||
trigger: log.trigger,
|
||||
createdAt: log.startedAt.toISOString(),
|
||||
files: log.files || undefined,
|
||||
workflow: workflowSummary,
|
||||
executionData: {
|
||||
totalDuration: log.totalDurationMs,
|
||||
...(log.executionData as any),
|
||||
enhanced: true,
|
||||
},
|
||||
cost: log.cost as any,
|
||||
}
|
||||
|
||||
return NextResponse.json({ data: response })
|
||||
} catch (error: any) {
|
||||
logger.error(`[${requestId}] log details fetch error`, error)
|
||||
return NextResponse.json({ error: error.message }, { status: 500 })
|
||||
}
|
||||
}
|
||||
@@ -99,21 +99,13 @@ export async function GET(request: NextRequest) {
|
||||
executionId: workflowExecutionLogs.executionId,
|
||||
stateSnapshotId: workflowExecutionLogs.stateSnapshotId,
|
||||
level: workflowExecutionLogs.level,
|
||||
message: workflowExecutionLogs.message,
|
||||
trigger: workflowExecutionLogs.trigger,
|
||||
startedAt: workflowExecutionLogs.startedAt,
|
||||
endedAt: workflowExecutionLogs.endedAt,
|
||||
totalDurationMs: workflowExecutionLogs.totalDurationMs,
|
||||
blockCount: workflowExecutionLogs.blockCount,
|
||||
successCount: workflowExecutionLogs.successCount,
|
||||
errorCount: workflowExecutionLogs.errorCount,
|
||||
skippedCount: workflowExecutionLogs.skippedCount,
|
||||
totalCost: workflowExecutionLogs.totalCost,
|
||||
totalInputCost: workflowExecutionLogs.totalInputCost,
|
||||
totalOutputCost: workflowExecutionLogs.totalOutputCost,
|
||||
totalTokens: workflowExecutionLogs.totalTokens,
|
||||
executionData: workflowExecutionLogs.executionData,
|
||||
cost: workflowExecutionLogs.cost,
|
||||
files: workflowExecutionLogs.files,
|
||||
metadata: workflowExecutionLogs.metadata,
|
||||
createdAt: workflowExecutionLogs.createdAt,
|
||||
})
|
||||
.from(workflowExecutionLogs)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user