mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-01-12 08:38:09 -05:00)
Compare commits
295 Commits
| SHA1 |
|---|
| 05afc1e1bb |
| a597507680 |
| 14c223435b |
| f6ad194306 |
| e5c21ceda9 |
| 1a4a3e72ec |
| 584b121a16 |
| d67c7a3704 |
| dec9dfad9d |
| 4af2297ca4 |
| 72e9dd6b37 |
| 987fe6d3b8 |
| 3c024be10c |
| e508d90fa7 |
| 91fb4f5c56 |
| 32c15434d2 |
| 6fcceeabdb |
| b1468f779c |
| ac03211c59 |
| c35453f41e |
| 07d732ece7 |
| 54baa84c28 |
| 0fb33f6cb7 |
| afd809f68a |
| a8ae006967 |
| b7603f6053 |
| c91099a5ef |
| c7ae4cfbda |
| b5e1c8075e |
| d14cecb954 |
| 9bc72305c5 |
| 980c90c07c |
| 006e843461 |
| 452fe314f6 |
| b86b0c758e |
| f1053ff8b4 |
| 7a1176073d |
| 3145f3d59d |
| 89e8fe3854 |
| 6c34f8bd96 |
| 6ccf62c3a6 |
| 0217b75d65 |
| 6a712fd6bb |
| f4244e5038 |
| 9c0e8750b0 |
| 14ecaf861e |
| 6d6ed348fa |
| 1d81c61b77 |
| 3da5fc2cac |
| 8d8664a3ce |
| d201b653c8 |
| f77172ec82 |
| ec072ad52a |
| 9b456612aa |
| 3d0ec9c52a |
| 50d08654f8 |
| fa871073ca |
| ab4f4549d6 |
| 4492995d6b |
| 6a736cc60f |
| 60417545f0 |
| 4f20f868c1 |
| 2ed8347430 |
| 12eb3b2937 |
| d5580f8a94 |
| 5b0ecfd26b |
| 3b24b0ac0d |
| 715d425579 |
| 0eda63fa15 |
| 4e886bd6e9 |
| f3168ea187 |
| 7fdfffdfcc |
| 8f4d552909 |
| 303a55145d |
| f54cfee4a7 |
| b3bd0f5d54 |
| 91aa371220 |
| 1ab19dcc56 |
| 30a047eac3 |
| f2387147c7 |
| 0f77f931ab |
| fdb82eda38 |
| 53eee63161 |
| 48d27c91d4 |
| 39c9e4a76c |
| b3fd8bbfb9 |
| a4b186cf81 |
| a971d59974 |
| 9db8832d6b |
| acb35c3926 |
| cd9c4218b0 |
| 3ccf0138b1 |
| 19ff8f324d |
| f445918abf |
| 5a4083d542 |
| 1e66137c7e |
| bb13157d18 |
| 37e6b6f385 |
| fabf742601 |
| 100b667afc |
| da10f1a2df |
| 65ada3fb72 |
| bc277acf57 |
| 52d19d084c |
| 86a10858bd |
| f297e42be4 |
| 2bbac5714f |
| e063f4bcda |
| 6c64c5b98f |
| 5530db63de |
| 4278ae61b0 |
| 0c78edb592 |
| 93d3bd3773 |
| c1c3fd4982 |
| d4b69d864f |
| db3284830a |
| d16cf6cfeb |
| 3f3919a843 |
| 244171d748 |
| faa683b6e4 |
| 9255759e1e |
| abbed4051d |
| 8c5380d4f9 |
| cacc6e1f86 |
| 9616baf695 |
| 518f196e6b |
| a9693b582f |
| 64fcba3f3a |
| a53f3f0e0a |
| 84cdf189f4 |
| 610a5b9943 |
| 18b5f2047c |
| 17db193faa |
| 78476630cd |
| 7a41f36b13 |
| c315b8e700 |
| f6c1bdccac |
| afe5c12afb |
| 65344b9783 |
| 9f71fb940d |
| 28a327c57a |
| 6aba1bce62 |
| 5db220c568 |
| 8e63a4a8d7 |
| 175f17b131 |
| 9aec1f51ed |
| 760e2ff592 |
| d17ea2d62a |
| 6351ba7f5d |
| a1ba3b1ac3 |
| d78b4d9ab4 |
| 1292c85d2a |
| c44fd7332c |
| 45a2826df8 |
| abb8134761 |
| c879599871 |
| a71b2a1de6 |
| 38b20e6158 |
| a8a0da1e3c |
| d3e7aab796 |
| f539c24571 |
| 2068073e8c |
| 04d36194c9 |
| e8eda51b27 |
| aa5d304a2e |
| 9dab7c9132 |
| 4562606c54 |
| fa933ada85 |
| bc1d11bf42 |
| 5ac2e7044e |
| 7f741468dd |
| 715e4c7d73 |
| ef473bbc8d |
| 43ac5e0343 |
| da15408a35 |
| bd00338690 |
| c63ccb5bd9 |
| 0e9906ea65 |
| ff0e786202 |
| deee943c3a |
| 40f38fcb46 |
| 35ec676f35 |
| c345a79962 |
| b9c7d1a115 |
| bf459a17ba |
| 15befae65f |
| 2340e9b3f5 |
| 7044782689 |
| ef4776f697 |
| 946ba02969 |
| 9a6ff408b4 |
| c07ea4f63d |
| 2af974b381 |
| 7e6bdf8b04 |
| 707c485212 |
| 5145aa7609 |
| 496990c096 |
| 52de22469f |
| c27f163623 |
| 29a61abfe3 |
| 1580fb5fa7 |
| b9763aa28a |
| f9eefae1ad |
| d16c1f259b |
| f0650d1f76 |
| d9710ce1af |
| e5a4b9a5ac |
| ebad48481e |
| 4503dab267 |
| d73acd13cb |
| a955794acd |
| 03d754cb50 |
| 8dc22e2a63 |
| bacdc190e7 |
| 5943c75873 |
| 8db695932e |
| a70d6a5193 |
| 4869a8ce22 |
| d5d9ecd71c |
| 52119eadc2 |
| 334b4d5ef9 |
| d7fc2dfb46 |
| 120469c8bf |
| 259d6f2b69 |
| dce436fb30 |
| 44cb4e8e77 |
| 88767a84d1 |
| 8b03477d2d |
| 10a2b36dc9 |
| b1ccdacd98 |
| 8811011286 |
| b8c764ad70 |
| 4d69b2eb75 |
| 5adc6c0a46 |
| 915b08d8a7 |
| 2997b12367 |
| 82f0ee2240 |
| 26bef8b918 |
| 82f553ec0d |
| 4ea85b5eaf |
| 45efe5a947 |
| c2b320dd6a |
| 6b2d264414 |
| 9a742cbe93 |
| f96f2f101b |
| 6318a976b5 |
| c2c39a0cd6 |
| 657e64d903 |
| a2e681a09f |
| 1fc3b2aa0a |
| 5465296ba6 |
| fbfb8838fd |
| 130261c75f |
| c8c954d862 |
| abfa707f69 |
| a01f326f7e |
| e377711c7e |
| 4c40b5f187 |
| a8c5264e17 |
| 9e8f76b749 |
| 43f8ed55ea |
| 5d5f14c799 |
| 9db01a7836 |
| 1e4a96883f |
| a34dc25b34 |
| 90d3954e8c |
| aeb43b7d37 |
| 7dedcaddb6 |
| 02463a5cb2 |
| 4c18763e55 |
| 8a9a1b59a4 |
| 8d9b282376 |
| ffdc457dea |
| 015ac85a83 |
| ee01a602ff |
| f7f4207902 |
| 37bae48bd9 |
| 2ef21f929f |
| 6b77d71b88 |
| e2946e0cd1 |
| 887e7a4f0a |
| c4a8ab8a19 |
| a3dda5d5a2 |
| 948279a67d |
| cf97e25b4a |
| 72e4eb2418 |
| e241a4cc2a |
| 3accc65d44 |
| fd6f23f4a6 |
| 5ad56c553f |
| 81471d8ffe |
| e46bf6300e |
| a0e87867b7 |
| 00d5d843a2 |
| 3ce7cf2713 |
@@ -1,61 +1,40 @@
# Ignore everything by default, selectively add things to context
*
classic/run

# Platform - Libs
!autogpt_platform/autogpt_libs/autogpt_libs/
!autogpt_platform/autogpt_libs/pyproject.toml
!autogpt_platform/autogpt_libs/poetry.lock
!autogpt_platform/autogpt_libs/README.md

# Platform - Backend
!autogpt_platform/backend/backend/
!autogpt_platform/backend/migrations/
!autogpt_platform/backend/schema.prisma
!autogpt_platform/backend/pyproject.toml
!autogpt_platform/backend/poetry.lock
!autogpt_platform/backend/README.md

# Platform - Market
!autogpt_platform/market/market/
!autogpt_platform/market/scripts.py
!autogpt_platform/market/schema.prisma
!autogpt_platform/market/pyproject.toml
!autogpt_platform/market/poetry.lock
!autogpt_platform/market/README.md

# Platform - Frontend
!autogpt_platform/frontend/src/
!autogpt_platform/frontend/public/
!autogpt_platform/frontend/package.json
!autogpt_platform/frontend/yarn.lock
!autogpt_platform/frontend/tsconfig.json
!autogpt_platform/frontend/README.md
## config
!autogpt_platform/frontend/*.config.*
!autogpt_platform/frontend/.env.*

# Classic - AutoGPT
# AutoGPT
!classic/original_autogpt/autogpt/
!classic/original_autogpt/pyproject.toml
!classic/original_autogpt/poetry.lock
!classic/original_autogpt/README.md
!classic/original_autogpt/tests/

# Classic - Benchmark
# Benchmark
!classic/benchmark/agbenchmark/
!classic/benchmark/pyproject.toml
!classic/benchmark/poetry.lock
!classic/benchmark/README.md

# Classic - Forge
# Forge
!classic/forge/
!classic/forge/pyproject.toml
!classic/forge/poetry.lock
!classic/forge/README.md

# Classic - Frontend
# Frontend
!classic/frontend/build/web/

# Platform
!autogpt_platform/

# Explicitly re-ignore some folders
.*
**/__pycache__

autogpt_platform/frontend/.next/
autogpt_platform/frontend/node_modules
autogpt_platform/frontend/.env.example
autogpt_platform/frontend/.env.local
autogpt_platform/backend/.env
autogpt_platform/backend/.venv/

autogpt_platform/market/.env
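The ignore file above uses a deny-all allowlist: `*` excludes everything from the build context, the `!` entries re-include what the images actually need, and later plain entries (`.next/`, `.env`, `.venv/`) re-exclude specifics. Evaluation is last-match-wins; below is a minimal Python sketch of that rule, simplified to `fnmatch` globbing (real `.dockerignore` matching uses Go's `filepath.Match` semantics plus `**` handling, so this is an illustration, not the Docker implementation):

```python
from fnmatch import fnmatch

def is_ignored(path: str, patterns: list[str]) -> bool:
    """Simplified .dockerignore check: the last matching pattern wins."""
    ignored = False
    for pattern in patterns:
        negated = pattern.startswith("!")
        if negated:
            pattern = pattern[1:]
        # Match the path itself, or anything beneath a matched directory.
        if fnmatch(path, pattern) or fnmatch(path, pattern.rstrip("/") + "/*"):
            ignored = not negated
    return ignored

patterns = ["*", "!autogpt_platform/backend/backend/", "autogpt_platform/backend/.venv/"]
print(is_ignored("classic/run", patterns))                                # True
print(is_ignored("autogpt_platform/backend/backend/app.py", patterns))    # False (re-included)
print(is_ignored("autogpt_platform/backend/.venv/bin/python", patterns))  # True (re-excluded)
```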
22 .github/dependabot.yml vendored
@@ -89,6 +89,28 @@ updates:
          - "minor"
          - "patch"

  # market (Poetry project)
  - package-ecosystem: "pip"
    directory: "autogpt_platform/market"
    schedule:
      interval: "weekly"
    open-pull-requests-limit: 10
    target-branch: "dev"
    commit-message:
      prefix: "chore(market/deps)"
      prefix-development: "chore(market/deps-dev)"
    groups:
      production-dependencies:
        dependency-type: "production"
        update-types:
          - "minor"
          - "patch"
      development-dependencies:
        dependency-type: "development"
        update-types:
          - "minor"
          - "patch"

  # GitHub Actions
  - package-ecosystem: "github-actions"
@@ -35,6 +35,12 @@ jobs:
        env:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}

      - name: Run Market Migrations
        working-directory: ./autogpt_platform/market
        run: |
          python -m prisma migrate deploy
        env:
          DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}

  trigger:
    needs: migrate
@@ -37,6 +37,13 @@ jobs:
        env:
          DATABASE_URL: ${{ secrets.BACKEND_DATABASE_URL }}

      - name: Run Market Migrations
        working-directory: ./autogpt_platform/market
        run: |
          python -m prisma migrate deploy
        env:
          DATABASE_URL: ${{ secrets.MARKET_DATABASE_URL }}

  trigger:
    needs: migrate
    runs-on: ubuntu-latest
2 .github/workflows/platform-backend-ci.yml vendored
@@ -81,7 +81,7 @@ jobs:

      - name: Check poetry.lock
        run: |
          poetry lock
          poetry lock --no-update

          if ! git diff --quiet poetry.lock; then
            echo "Error: poetry.lock not up to date."
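This step re-resolves the lockfile and fails the build if doing so changes it. A rough Python equivalent of the same freshness check, assuming `poetry` and `git` are on `PATH` and the current directory is the project root:

```python
import subprocess
import sys

def check_lockfile_fresh() -> None:
    """Regenerate poetry.lock, then fail if git sees it was modified (it was stale)."""
    subprocess.run(["poetry", "lock"], check=True)
    result = subprocess.run(["git", "diff", "--quiet", "poetry.lock"])
    if result.returncode != 0:  # non-zero exit: the file differs from HEAD
        sys.exit("Error: poetry.lock not up to date.")

if __name__ == "__main__":
    check_lockfile_fresh()
```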
7 .github/workflows/platform-frontend-ci.yml vendored
@@ -42,7 +42,7 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        browser: [chromium, webkit]
        browser: [chromium, firefox, webkit]

    steps:
      - name: Checkout repository
@@ -88,11 +88,6 @@ jobs:
        run: |
          yarn test --project=${{ matrix.browser }}

      - name: Print Docker Compose logs in debug mode
        if: runner.debug
        run: |
          docker compose -f ../docker-compose.yml logs

      - uses: actions/upload-artifact@v4
        if: ${{ !cancelled() }}
        with:
126 .github/workflows/platform-market-ci.yml vendored Normal file
@@ -0,0 +1,126 @@
name: AutoGPT Platform - Backend CI

on:
  push:
    branches: [master, dev, ci-test*]
    paths:
      - ".github/workflows/platform-market-ci.yml"
      - "autogpt_platform/market/**"
  pull_request:
    branches: [master, dev, release-*]
    paths:
      - ".github/workflows/platform-market-ci.yml"
      - "autogpt_platform/market/**"
  merge_group:

concurrency:
  group: ${{ format('backend-ci-{0}', github.head_ref && format('{0}-{1}', github.event_name, github.event.pull_request.number) || github.sha) }}
  cancel-in-progress: ${{ startsWith(github.event_name, 'pull_request') }}

defaults:
  run:
    shell: bash
    working-directory: autogpt_platform/market

jobs:
  test:
    permissions:
      contents: read
    timeout-minutes: 30
    strategy:
      fail-fast: false
      matrix:
        python-version: ["3.10"]
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          submodules: true

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Setup Supabase
        uses: supabase/setup-cli@v1
        with:
          version: latest

      - id: get_date
        name: Get date
        run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT

      - name: Set up Python dependency cache
        uses: actions/cache@v4
        with:
          path: ~/.cache/pypoetry
          key: poetry-${{ runner.os }}-${{ hashFiles('autogpt_platform/market/poetry.lock') }}

      - name: Install Poetry (Unix)
        run: |
          curl -sSL https://install.python-poetry.org | python3 -

          if [ "${{ runner.os }}" = "macOS" ]; then
            PATH="$HOME/.local/bin:$PATH"
            echo "$HOME/.local/bin" >> $GITHUB_PATH
          fi

      - name: Install Python dependencies
        run: poetry install

      - name: Generate Prisma Client
        run: poetry run prisma generate

      - id: supabase
        name: Start Supabase
        working-directory: .
        run: |
          supabase init
          supabase start --exclude postgres-meta,realtime,storage-api,imgproxy,inbucket,studio,edge-runtime,logflare,vector,supavisor
          supabase status -o env | sed 's/="/=/; s/"$//' >> $GITHUB_OUTPUT
          # outputs:
          # DB_URL, API_URL, GRAPHQL_URL, ANON_KEY, SERVICE_ROLE_KEY, JWT_SECRET

      - name: Run Database Migrations
        run: poetry run prisma migrate dev --name updates
        env:
          DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}

      - id: lint
        name: Run Linter
        run: poetry run lint

      # Tests comment out because they do not work with prisma mock, nor have they been updated since they were created
      # - name: Run pytest with coverage
      #   run: |
      #     if [[ "${{ runner.debug }}" == "1" ]]; then
      #       poetry run pytest -s -vv -o log_cli=true -o log_cli_level=DEBUG test
      #     else
      #       poetry run pytest -s -vv test
      #     fi
      #   if: success() || (failure() && steps.lint.outcome == 'failure')
      #   env:
      #     LOG_LEVEL: ${{ runner.debug && 'DEBUG' || 'INFO' }}
      #     DATABASE_URL: ${{ steps.supabase.outputs.DB_URL }}
      #     SUPABASE_URL: ${{ steps.supabase.outputs.API_URL }}
      #     SUPABASE_SERVICE_ROLE_KEY: ${{ steps.supabase.outputs.SERVICE_ROLE_KEY }}
      #     SUPABASE_JWT_SECRET: ${{ steps.supabase.outputs.JWT_SECRET }}
      #     REDIS_HOST: 'localhost'
      #     REDIS_PORT: '6379'
      #     REDIS_PASSWORD: 'testpassword'

        env:
          CI: true
          PLAIN_OUTPUT: True
          RUN_ENV: local
          PORT: 8080

      # - name: Upload coverage reports to Codecov
      #   uses: codecov/codecov-action@v4
      #   with:
      #     token: ${{ secrets.CODECOV_TOKEN }}
      #     flags: backend,${{ runner.os }}
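The `Start Supabase` step's `sed 's/="/=/; s/"$//'` turns `supabase status -o env` lines such as `DB_URL="..."` into unquoted `KEY=value` pairs that GitHub Actions accepts in `$GITHUB_OUTPUT`. A small Python sketch of the same transformation (the sample line is illustrative, not captured CI output):

```python
import re

def to_github_output(env_lines: list[str]) -> list[str]:
    """Mirror sed 's/="/=/; s/"$//': strip the quotes around each value."""
    stripped = []
    for line in env_lines:
        line = line.replace('="', "=", 1)  # s/="/=/
        line = re.sub(r'"$', "", line)     # s/"$//
        stripped.append(line)
    return stripped

sample = ['DB_URL="postgresql://postgres:postgres@127.0.0.1:54322/postgres"']
print(to_github_output(sample))
# ['DB_URL=postgresql://postgres:postgres@127.0.0.1:54322/postgres']
```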
@@ -110,7 +110,7 @@ repos:
      - id: isort
        name: Lint (isort) - AutoGPT Platform - Backend
        alias: isort-platform-backend
        entry: poetry -P autogpt_platform/backend run isort -p backend
        entry: poetry -C autogpt_platform/backend run isort -p backend
        files: ^autogpt_platform/backend/
        types: [file, python]
        language: system
@@ -118,7 +118,7 @@ repos:
      - id: isort
        name: Lint (isort) - Classic - AutoGPT
        alias: isort-classic-autogpt
        entry: poetry -P classic/original_autogpt run isort -p autogpt
        entry: poetry -C classic/original_autogpt run isort -p autogpt
        files: ^classic/original_autogpt/
        types: [file, python]
        language: system
@@ -126,7 +126,7 @@ repos:
      - id: isort
        name: Lint (isort) - Classic - Forge
        alias: isort-classic-forge
        entry: poetry -P classic/forge run isort -p forge
        entry: poetry -C classic/forge run isort -p forge
        files: ^classic/forge/
        types: [file, python]
        language: system
@@ -134,7 +134,7 @@ repos:
      - id: isort
        name: Lint (isort) - Classic - Benchmark
        alias: isort-classic-benchmark
        entry: poetry -P classic/benchmark run isort -p agbenchmark
        entry: poetry -C classic/benchmark run isort -p agbenchmark
        files: ^classic/benchmark/
        types: [file, python]
        language: system
@@ -178,6 +178,7 @@ repos:
        name: Typecheck - AutoGPT Platform - Backend
        alias: pyright-platform-backend
        entry: poetry -C autogpt_platform/backend run pyright
        args: [-p, autogpt_platform/backend, autogpt_platform/backend]
        # include forge source (since it's a path dependency) but exclude *_test.py files:
        files: ^autogpt_platform/(backend/((backend|test)/|(\w+\.py|poetry\.lock)$)|autogpt_libs/(autogpt_libs/.*(?<!_test)\.py|poetry\.lock)$)
        types: [file]
@@ -188,6 +189,7 @@ repos:
        name: Typecheck - AutoGPT Platform - Libs
        alias: pyright-platform-libs
        entry: poetry -C autogpt_platform/autogpt_libs run pyright
        args: [-p, autogpt_platform/autogpt_libs, autogpt_platform/autogpt_libs]
        files: ^autogpt_platform/autogpt_libs/(autogpt_libs/|poetry\.lock$)
        types: [file]
        language: system
@@ -197,6 +199,7 @@ repos:
        name: Typecheck - Classic - AutoGPT
        alias: pyright-classic-autogpt
        entry: poetry -C classic/original_autogpt run pyright
        args: [-p, classic/original_autogpt, classic/original_autogpt]
        # include forge source (since it's a path dependency) but exclude *_test.py files:
        files: ^(classic/original_autogpt/((autogpt|scripts|tests)/|poetry\.lock$)|classic/forge/(forge/.*(?<!_test)\.py|poetry\.lock)$)
        types: [file]
@@ -207,6 +210,7 @@ repos:
        name: Typecheck - Classic - Forge
        alias: pyright-classic-forge
        entry: poetry -C classic/forge run pyright
        args: [-p, classic/forge, classic/forge]
        files: ^classic/forge/(forge/|poetry\.lock$)
        types: [file]
        language: system
@@ -216,6 +220,7 @@ repos:
        name: Typecheck - Classic - Benchmark
        alias: pyright-classic-benchmark
        entry: poetry -C classic/benchmark run pyright
        args: [-p, classic/benchmark, classic/benchmark]
        files: ^classic/benchmark/(agbenchmark/|tests/|poetry\.lock$)
        types: [file]
        language: system
@@ -1,76 +0,0 @@
from typing import Annotated, Any, Literal, Optional, TypedDict
from uuid import uuid4

from pydantic import BaseModel, Field, SecretStr, field_serializer


class _BaseCredentials(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid4()))
    provider: str
    title: Optional[str]

    @field_serializer("*")
    def dump_secret_strings(value: Any, _info):
        if isinstance(value, SecretStr):
            return value.get_secret_value()
        return value


class OAuth2Credentials(_BaseCredentials):
    type: Literal["oauth2"] = "oauth2"
    username: Optional[str]
    """Username of the third-party service user that these credentials belong to"""
    access_token: SecretStr
    access_token_expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the access token expires (if at all)"""
    refresh_token: Optional[SecretStr]
    refresh_token_expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the refresh token expires (if at all)"""
    scopes: list[str]
    metadata: dict[str, Any] = Field(default_factory=dict)

    def bearer(self) -> str:
        return f"Bearer {self.access_token.get_secret_value()}"


class APIKeyCredentials(_BaseCredentials):
    type: Literal["api_key"] = "api_key"
    api_key: SecretStr
    expires_at: Optional[int]
    """Unix timestamp (seconds) indicating when the API key expires (if at all)"""

    def bearer(self) -> str:
        return f"Bearer {self.api_key.get_secret_value()}"


Credentials = Annotated[
    OAuth2Credentials | APIKeyCredentials,
    Field(discriminator="type"),
]


CredentialsType = Literal["api_key", "oauth2"]


class OAuthState(BaseModel):
    token: str
    provider: str
    expires_at: int
    code_verifier: Optional[str] = None
    scopes: list[str]
    """Unix timestamp (seconds) indicating when this OAuth state expires"""


class UserMetadata(BaseModel):
    integration_credentials: list[Credentials] = Field(default_factory=list)
    integration_oauth_states: list[OAuthState] = Field(default_factory=list)


class UserMetadataRaw(TypedDict, total=False):
    integration_credentials: list[dict]
    integration_oauth_states: list[dict]


class UserIntegrations(BaseModel):
    credentials: list[Credentials] = Field(default_factory=list)
    oauth_states: list[OAuthState] = Field(default_factory=list)
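The key piece of this deleted module is the `@field_serializer("*")` hook: `SecretStr` masks token values in reprs and logs, while the serializer unwraps them on `model_dump()` so credentials can still be persisted. A self-contained sketch of that behavior, reusing the serializer verbatim (the `ApiKey` model and sample values are illustrative, not from the repo):

```python
from typing import Any

from pydantic import BaseModel, SecretStr, field_serializer

class ApiKey(BaseModel):
    provider: str
    api_key: SecretStr

    @field_serializer("*")
    def dump_secret_strings(value: Any, _info):
        # Unwrap SecretStr so the serialized form can be persisted.
        if isinstance(value, SecretStr):
            return value.get_secret_value()
        return value

creds = ApiKey(provider="github", api_key=SecretStr("ghp_example"))
print(creds)               # api_key=SecretStr('**********'), masked in the repr
print(creds.model_dump())  # {'provider': 'github', 'api_key': 'ghp_example'}
```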
@@ -31,8 +31,7 @@ class RedisKeyedMutex:
        try:
            yield
        finally:
            if lock.locked():
                lock.release()
            lock.release()

    def acquire(self, key: Any) -> "RedisLock":
        """Acquires and returns a lock with the given key"""
@@ -46,7 +45,7 @@ class RedisKeyedMutex:
        return lock

    def release(self, key: Any):
        if (lock := self.locks.get(key)) and lock.locked() and lock.owned():
        if lock := self.locks.get(key):
            lock.release()

    def release_all_locks(self):
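Both hunks concern when it is safe to call `release()`: redis-py raises `LockError` when releasing a lock that has expired or is held by another client, so the stricter variant guards on `locked()` (still held by anyone) and `owned()` (held by this client). A standalone sketch of the guarded pattern with redis-py, assuming a Redis server on localhost and an illustrative key name:

```python
from redis import Redis

client = Redis(host="localhost", port=6379)

lock = client.lock("lock:example-key", timeout=60)
lock.acquire()
try:
    ...  # critical section
finally:
    # Guard against double-release and expiry: only release if the lock
    # is still held (locked) and held by *this* client (owned).
    if lock.locked() and lock.owned():
        lock.release()
```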
268 autogpt_platform/autogpt_libs/poetry.lock generated
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.

[[package]]
name = "aiohappyeyeballs"
@@ -1091,19 +1091,22 @@ pyasn1 = ">=0.4.6,<0.7.0"

[[package]]
name = "pydantic"
version = "2.10.3"
version = "2.9.2"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.8"
files = [
    {file = "pydantic-2.10.3-py3-none-any.whl", hash = "sha256:be04d85bbc7b65651c5f8e6b9976ed9c6f41782a55524cef079a34a0bb82144d"},
    {file = "pydantic-2.10.3.tar.gz", hash = "sha256:cb5ac360ce894ceacd69c403187900a02c4b20b693a9dd1d643e1effab9eadf9"},
    {file = "pydantic-2.9.2-py3-none-any.whl", hash = "sha256:f048cec7b26778210e28a0459867920654d48e5e62db0958433636cde4254f12"},
    {file = "pydantic-2.9.2.tar.gz", hash = "sha256:d155cef71265d1e9807ed1c32b4c8deec042a44a50a4188b25ac67ecd81a9c0f"},
]

[package.dependencies]
annotated-types = ">=0.6.0"
pydantic-core = "2.27.1"
typing-extensions = ">=4.12.2"
pydantic-core = "2.23.4"
typing-extensions = [
    {version = ">=4.12.2", markers = "python_version >= \"3.13\""},
    {version = ">=4.6.1", markers = "python_version < \"3.13\""},
]

[package.extras]
email = ["email-validator (>=2.0.0)"]
@@ -1111,111 +1114,100 @@ timezone = ["tzdata"]

[[package]]
name = "pydantic-core"
version = "2.27.1"
version = "2.23.4"
description = "Core functionality for Pydantic validation and serialization"
optional = false
python-versions = ">=3.8"
files = [
    {file = "pydantic_core-2.27.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:71a5e35c75c021aaf400ac048dacc855f000bdfed91614b4a726f7432f1f3d6a"},
    {file = "pydantic_core-2.27.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f82d068a2d6ecfc6e054726080af69a6764a10015467d7d7b9f66d6ed5afa23b"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:121ceb0e822f79163dd4699e4c54f5ad38b157084d97b34de8b232bcaad70278"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4603137322c18eaf2e06a4495f426aa8d8388940f3c457e7548145011bb68e05"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a33cd6ad9017bbeaa9ed78a2e0752c5e250eafb9534f308e7a5f7849b0b1bfb4"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:15cc53a3179ba0fcefe1e3ae50beb2784dede4003ad2dfd24f81bba4b23a454f"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45d9c5eb9273aa50999ad6adc6be5e0ecea7e09dbd0d31bd0c65a55a2592ca08"},
    {file = "pydantic_core-2.27.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8bf7b66ce12a2ac52d16f776b31d16d91033150266eb796967a7e4621707e4f6"},
    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:655d7dd86f26cb15ce8a431036f66ce0318648f8853d709b4167786ec2fa4807"},
    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:5556470f1a2157031e676f776c2bc20acd34c1990ca5f7e56f1ebf938b9ab57c"},
    {file = "pydantic_core-2.27.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f69ed81ab24d5a3bd93861c8c4436f54afdf8e8cc421562b0c7504cf3be58206"},
    {file = "pydantic_core-2.27.1-cp310-none-win32.whl", hash = "sha256:f5a823165e6d04ccea61a9f0576f345f8ce40ed533013580e087bd4d7442b52c"},
    {file = "pydantic_core-2.27.1-cp310-none-win_amd64.whl", hash = "sha256:57866a76e0b3823e0b56692d1a0bf722bffb324839bb5b7226a7dbd6c9a40b17"},
    {file = "pydantic_core-2.27.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:ac3b20653bdbe160febbea8aa6c079d3df19310d50ac314911ed8cc4eb7f8cb8"},
    {file = "pydantic_core-2.27.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a5a8e19d7c707c4cadb8c18f5f60c843052ae83c20fa7d44f41594c644a1d330"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f7059ca8d64fea7f238994c97d91f75965216bcbe5f695bb44f354893f11d52"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bed0f8a0eeea9fb72937ba118f9db0cb7e90773462af7962d382445f3005e5a4"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a3cb37038123447cf0f3ea4c74751f6a9d7afef0eb71aa07bf5f652b5e6a132c"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84286494f6c5d05243456e04223d5a9417d7f443c3b76065e75001beb26f88de"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acc07b2cfc5b835444b44a9956846b578d27beeacd4b52e45489e93276241025"},
    {file = "pydantic_core-2.27.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4fefee876e07a6e9aad7a8c8c9f85b0cdbe7df52b8a9552307b09050f7512c7e"},
    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:258c57abf1188926c774a4c94dd29237e77eda19462e5bb901d88adcab6af919"},
    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:35c14ac45fcfdf7167ca76cc80b2001205a8d5d16d80524e13508371fb8cdd9c"},
    {file = "pydantic_core-2.27.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d1b26e1dff225c31897696cab7d4f0a315d4c0d9e8666dbffdb28216f3b17fdc"},
    {file = "pydantic_core-2.27.1-cp311-none-win32.whl", hash = "sha256:2cdf7d86886bc6982354862204ae3b2f7f96f21a3eb0ba5ca0ac42c7b38598b9"},
    {file = "pydantic_core-2.27.1-cp311-none-win_amd64.whl", hash = "sha256:3af385b0cee8df3746c3f406f38bcbfdc9041b5c2d5ce3e5fc6637256e60bbc5"},
    {file = "pydantic_core-2.27.1-cp311-none-win_arm64.whl", hash = "sha256:81f2ec23ddc1b476ff96563f2e8d723830b06dceae348ce02914a37cb4e74b89"},
    {file = "pydantic_core-2.27.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9cbd94fc661d2bab2bc702cddd2d3370bbdcc4cd0f8f57488a81bcce90c7a54f"},
    {file = "pydantic_core-2.27.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5f8c4718cd44ec1580e180cb739713ecda2bdee1341084c1467802a417fe0f02"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15aae984e46de8d376df515f00450d1522077254ef6b7ce189b38ecee7c9677c"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1ba5e3963344ff25fc8c40da90f44b0afca8cfd89d12964feb79ac1411a260ac"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:992cea5f4f3b29d6b4f7f1726ed8ee46c8331c6b4eed6db5b40134c6fe1768bb"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0325336f348dbee6550d129b1627cb8f5351a9dc91aad141ffb96d4937bd9529"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7597c07fbd11515f654d6ece3d0e4e5093edc30a436c63142d9a4b8e22f19c35"},
    {file = "pydantic_core-2.27.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:3bbd5d8cc692616d5ef6fbbbd50dbec142c7e6ad9beb66b78a96e9c16729b089"},
    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:dc61505e73298a84a2f317255fcc72b710b72980f3a1f670447a21efc88f8381"},
    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:e1f735dc43da318cad19b4173dd1ffce1d84aafd6c9b782b3abc04a0d5a6f5bb"},
    {file = "pydantic_core-2.27.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:f4e5658dbffe8843a0f12366a4c2d1c316dbe09bb4dfbdc9d2d9cd6031de8aae"},
    {file = "pydantic_core-2.27.1-cp312-none-win32.whl", hash = "sha256:672ebbe820bb37988c4d136eca2652ee114992d5d41c7e4858cdd90ea94ffe5c"},
    {file = "pydantic_core-2.27.1-cp312-none-win_amd64.whl", hash = "sha256:66ff044fd0bb1768688aecbe28b6190f6e799349221fb0de0e6f4048eca14c16"},
    {file = "pydantic_core-2.27.1-cp312-none-win_arm64.whl", hash = "sha256:9a3b0793b1bbfd4146304e23d90045f2a9b5fd5823aa682665fbdaf2a6c28f3e"},
    {file = "pydantic_core-2.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f216dbce0e60e4d03e0c4353c7023b202d95cbaeff12e5fd2e82ea0a66905073"},
    {file = "pydantic_core-2.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a2e02889071850bbfd36b56fd6bc98945e23670773bc7a76657e90e6b6603c08"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42b0e23f119b2b456d07ca91b307ae167cc3f6c846a7b169fca5326e32fdc6cf"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:764be71193f87d460a03f1f7385a82e226639732214b402f9aa61f0d025f0737"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c00666a3bd2f84920a4e94434f5974d7bbc57e461318d6bb34ce9cdbbc1f6b2"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3ccaa88b24eebc0f849ce0a4d09e8a408ec5a94afff395eb69baf868f5183107"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c65af9088ac534313e1963443d0ec360bb2b9cba6c2909478d22c2e363d98a51"},
    {file = "pydantic_core-2.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:206b5cf6f0c513baffaeae7bd817717140770c74528f3e4c3e1cec7871ddd61a"},
    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:062f60e512fc7fff8b8a9d680ff0ddaaef0193dba9fa83e679c0c5f5fbd018bc"},
    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:a0697803ed7d4af5e4c1adf1670af078f8fcab7a86350e969f454daf598c4960"},
    {file = "pydantic_core-2.27.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:58ca98a950171f3151c603aeea9303ef6c235f692fe555e883591103da709b23"},
    {file = "pydantic_core-2.27.1-cp313-none-win32.whl", hash = "sha256:8065914ff79f7eab1599bd80406681f0ad08f8e47c880f17b416c9f8f7a26d05"},
    {file = "pydantic_core-2.27.1-cp313-none-win_amd64.whl", hash = "sha256:ba630d5e3db74c79300d9a5bdaaf6200172b107f263c98a0539eeecb857b2337"},
    {file = "pydantic_core-2.27.1-cp313-none-win_arm64.whl", hash = "sha256:45cf8588c066860b623cd11c4ba687f8d7175d5f7ef65f7129df8a394c502de5"},
    {file = "pydantic_core-2.27.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:5897bec80a09b4084aee23f9b73a9477a46c3304ad1d2d07acca19723fb1de62"},
    {file = "pydantic_core-2.27.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d0165ab2914379bd56908c02294ed8405c252250668ebcb438a55494c69f44ab"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b9af86e1d8e4cfc82c2022bfaa6f459381a50b94a29e95dcdda8442d6d83864"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f6c8a66741c5f5447e047ab0ba7a1c61d1e95580d64bce852e3df1f895c4067"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a42d6a8156ff78981f8aa56eb6394114e0dedb217cf8b729f438f643608cbcd"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64c65f40b4cd8b0e049a8edde07e38b476da7e3aaebe63287c899d2cff253fa5"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdcf339322a3fae5cbd504edcefddd5a50d9ee00d968696846f089b4432cf78"},
    {file = "pydantic_core-2.27.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bf99c8404f008750c846cb4ac4667b798a9f7de673ff719d705d9b2d6de49c5f"},
    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8f1edcea27918d748c7e5e4d917297b2a0ab80cad10f86631e488b7cddf76a36"},
    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_armv7l.whl", hash = "sha256:159cac0a3d096f79ab6a44d77a961917219707e2a130739c64d4dd46281f5c2a"},
    {file = "pydantic_core-2.27.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:029d9757eb621cc6e1848fa0b0310310de7301057f623985698ed7ebb014391b"},
    {file = "pydantic_core-2.27.1-cp38-none-win32.whl", hash = "sha256:a28af0695a45f7060e6f9b7092558a928a28553366519f64083c63a44f70e618"},
    {file = "pydantic_core-2.27.1-cp38-none-win_amd64.whl", hash = "sha256:2d4567c850905d5eaaed2f7a404e61012a51caf288292e016360aa2b96ff38d4"},
    {file = "pydantic_core-2.27.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:e9386266798d64eeb19dd3677051f5705bf873e98e15897ddb7d76f477131967"},
    {file = "pydantic_core-2.27.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4228b5b646caa73f119b1ae756216b59cc6e2267201c27d3912b592c5e323b60"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b3dfe500de26c52abe0477dde16192ac39c98f05bf2d80e76102d394bd13854"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:aee66be87825cdf72ac64cb03ad4c15ffef4143dbf5c113f64a5ff4f81477bf9"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3b748c44bb9f53031c8cbc99a8a061bc181c1000c60a30f55393b6e9c45cc5bd"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ca038c7f6a0afd0b2448941b6ef9d5e1949e999f9e5517692eb6da58e9d44be"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6e0bd57539da59a3e4671b90a502da9a28c72322a4f17866ba3ac63a82c4498e"},
    {file = "pydantic_core-2.27.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ac6c2c45c847bbf8f91930d88716a0fb924b51e0c6dad329b793d670ec5db792"},
    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b94d4ba43739bbe8b0ce4262bcc3b7b9f31459ad120fb595627eaeb7f9b9ca01"},
    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:00e6424f4b26fe82d44577b4c842d7df97c20be6439e8e685d0d715feceb9fb9"},
    {file = "pydantic_core-2.27.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:38de0a70160dd97540335b7ad3a74571b24f1dc3ed33f815f0880682e6880131"},
    {file = "pydantic_core-2.27.1-cp39-none-win32.whl", hash = "sha256:7ccebf51efc61634f6c2344da73e366c75e735960b5654b63d7e6f69a5885fa3"},
    {file = "pydantic_core-2.27.1-cp39-none-win_amd64.whl", hash = "sha256:a57847b090d7892f123726202b7daa20df6694cbd583b67a592e856bff603d6c"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3fa80ac2bd5856580e242dbc202db873c60a01b20309c8319b5c5986fbe53ce6"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d950caa237bb1954f1b8c9227b5065ba6875ac9771bb8ec790d956a699b78676"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e4216e64d203e39c62df627aa882f02a2438d18a5f21d7f721621f7a5d3611d"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:02a3d637bd387c41d46b002f0e49c52642281edacd2740e5a42f7017feea3f2c"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:161c27ccce13b6b0c8689418da3885d3220ed2eae2ea5e9b2f7f3d48f1d52c27"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:19910754e4cc9c63bc1c7f6d73aa1cfee82f42007e407c0f413695c2f7ed777f"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:e173486019cc283dc9778315fa29a363579372fe67045e971e89b6365cc035ed"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:af52d26579b308921b73b956153066481f064875140ccd1dfd4e77db89dbb12f"},
    {file = "pydantic_core-2.27.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:981fb88516bd1ae8b0cbbd2034678a39dedc98752f264ac9bc5839d3923fa04c"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5fde892e6c697ce3e30c61b239330fc5d569a71fefd4eb6512fc6caec9dd9e2f"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:816f5aa087094099fff7edabb5e01cc370eb21aa1a1d44fe2d2aefdfb5599b31"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c10c309e18e443ddb108f0ef64e8729363adbfd92d6d57beec680f6261556f3"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98476c98b02c8e9b2eec76ac4156fd006628b1b2d0ef27e548ffa978393fd154"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c3027001c28434e7ca5a6e1e527487051136aa81803ac812be51802150d880dd"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:7699b1df36a48169cdebda7ab5a2bac265204003f153b4bd17276153d997670a"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1c39b07d90be6b48968ddc8c19e7585052088fd7ec8d568bb31ff64c70ae3c97"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:46ccfe3032b3915586e469d4972973f893c0a2bb65669194a5bdea9bacc088c2"},
    {file = "pydantic_core-2.27.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:62ba45e21cf6571d7f716d903b5b7b6d2617e2d5d67c0923dc47b9d41369f840"},
    {file = "pydantic_core-2.27.1.tar.gz", hash = "sha256:62a763352879b84aa31058fc931884055fd75089cccbd9d58bb6afd01141b235"},
    {file = "pydantic_core-2.23.4-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b10bd51f823d891193d4717448fab065733958bdb6a6b351967bd349d48d5c9b"},
    {file = "pydantic_core-2.23.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4fc714bdbfb534f94034efaa6eadd74e5b93c8fa6315565a222f7b6f42ca1166"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63e46b3169866bd62849936de036f901a9356e36376079b05efa83caeaa02ceb"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed1a53de42fbe34853ba90513cea21673481cd81ed1be739f7f2efb931b24916"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cfdd16ab5e59fc31b5e906d1a3f666571abc367598e3e02c83403acabc092e07"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:255a8ef062cbf6674450e668482456abac99a5583bbafb73f9ad469540a3a232"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a7cd62e831afe623fbb7aabbb4fe583212115b3ef38a9f6b71869ba644624a2"},
    {file = "pydantic_core-2.23.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f09e2ff1f17c2b51f2bc76d1cc33da96298f0a036a137f5440ab3ec5360b624f"},
    {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e38e63e6f3d1cec5a27e0afe90a085af8b6806ee208b33030e65b6516353f1a3"},
    {file = "pydantic_core-2.23.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0dbd8dbed2085ed23b5c04afa29d8fd2771674223135dc9bc937f3c09284d071"},
    {file = "pydantic_core-2.23.4-cp310-none-win32.whl", hash = "sha256:6531b7ca5f951d663c339002e91aaebda765ec7d61b7d1e3991051906ddde119"},
    {file = "pydantic_core-2.23.4-cp310-none-win_amd64.whl", hash = "sha256:7c9129eb40958b3d4500fa2467e6a83356b3b61bfff1b414c7361d9220f9ae8f"},
    {file = "pydantic_core-2.23.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:77733e3892bb0a7fa797826361ce8a9184d25c8dffaec60b7ffe928153680ba8"},
    {file = "pydantic_core-2.23.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b84d168f6c48fabd1f2027a3d1bdfe62f92cade1fb273a5d68e621da0e44e6d"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df49e7a0861a8c36d089c1ed57d308623d60416dab2647a4a17fe050ba85de0e"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ff02b6d461a6de369f07ec15e465a88895f3223eb75073ffea56b84d9331f607"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:996a38a83508c54c78a5f41456b0103c30508fed9abcad0a59b876d7398f25fd"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d97683ddee4723ae8c95d1eddac7c192e8c552da0c73a925a89fa8649bf13eea"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:216f9b2d7713eb98cb83c80b9c794de1f6b7e3145eef40400c62e86cee5f4e1e"},
    {file = "pydantic_core-2.23.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6f783e0ec4803c787bcea93e13e9932edab72068f68ecffdf86a99fd5918878b"},
    {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d0776dea117cf5272382634bd2a5c1b6eb16767c223c6a5317cd3e2a757c61a0"},
    {file = "pydantic_core-2.23.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d5f7a395a8cf1621939692dba2a6b6a830efa6b3cee787d82c7de1ad2930de64"},
    {file = "pydantic_core-2.23.4-cp311-none-win32.whl", hash = "sha256:74b9127ffea03643e998e0c5ad9bd3811d3dac8c676e47db17b0ee7c3c3bf35f"},
    {file = "pydantic_core-2.23.4-cp311-none-win_amd64.whl", hash = "sha256:98d134c954828488b153d88ba1f34e14259284f256180ce659e8d83e9c05eaa3"},
    {file = "pydantic_core-2.23.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f3e0da4ebaef65158d4dfd7d3678aad692f7666877df0002b8a522cdf088f231"},
    {file = "pydantic_core-2.23.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f69a8e0b033b747bb3e36a44e7732f0c99f7edd5cea723d45bc0d6e95377ffee"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:723314c1d51722ab28bfcd5240d858512ffd3116449c557a1336cbe3919beb87"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb2802e667b7051a1bebbfe93684841cc9351004e2badbd6411bf357ab8d5ac8"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d18ca8148bebe1b0a382a27a8ee60350091a6ddaf475fa05ef50dc35b5df6327"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:33e3d65a85a2a4a0dc3b092b938a4062b1a05f3a9abde65ea93b233bca0e03f2"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:128585782e5bfa515c590ccee4b727fb76925dd04a98864182b22e89a4e6ed36"},
    {file = "pydantic_core-2.23.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:68665f4c17edcceecc112dfed5dbe6f92261fb9d6054b47d01bf6371a6196126"},
    {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:20152074317d9bed6b7a95ade3b7d6054845d70584216160860425f4fbd5ee9e"},
    {file = "pydantic_core-2.23.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:9261d3ce84fa1d38ed649c3638feefeae23d32ba9182963e465d58d62203bd24"},
    {file = "pydantic_core-2.23.4-cp312-none-win32.whl", hash = "sha256:4ba762ed58e8d68657fc1281e9bb72e1c3e79cc5d464be146e260c541ec12d84"},
    {file = "pydantic_core-2.23.4-cp312-none-win_amd64.whl", hash = "sha256:97df63000f4fea395b2824da80e169731088656d1818a11b95f3b173747b6cd9"},
    {file = "pydantic_core-2.23.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:7530e201d10d7d14abce4fb54cfe5b94a0aefc87da539d0346a484ead376c3cc"},
    {file = "pydantic_core-2.23.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:df933278128ea1cd77772673c73954e53a1c95a4fdf41eef97c2b779271bd0bd"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cb3da3fd1b6a5d0279a01877713dbda118a2a4fc6f0d821a57da2e464793f05"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:42c6dcb030aefb668a2b7009c85b27f90e51e6a3b4d5c9bc4c57631292015b0d"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:696dd8d674d6ce621ab9d45b205df149399e4bb9aa34102c970b721554828510"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2971bb5ffe72cc0f555c13e19b23c85b654dd2a8f7ab493c262071377bfce9f6"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8394d940e5d400d04cad4f75c0598665cbb81aecefaca82ca85bd28264af7f9b"},
    {file = "pydantic_core-2.23.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:0dff76e0602ca7d4cdaacc1ac4c005e0ce0dcfe095d5b5259163a80d3a10d327"},
    {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7d32706badfe136888bdea71c0def994644e09fff0bfe47441deaed8e96fdbc6"},
    {file = "pydantic_core-2.23.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:ed541d70698978a20eb63d8c5d72f2cc6d7079d9d90f6b50bad07826f1320f5f"},
    {file = "pydantic_core-2.23.4-cp313-none-win32.whl", hash = "sha256:3d5639516376dce1940ea36edf408c554475369f5da2abd45d44621cb616f769"},
    {file = "pydantic_core-2.23.4-cp313-none-win_amd64.whl", hash = "sha256:5a1504ad17ba4210df3a045132a7baeeba5a200e930f57512ee02909fc5c4cb5"},
    {file = "pydantic_core-2.23.4-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:d4488a93b071c04dc20f5cecc3631fc78b9789dd72483ba15d423b5b3689b555"},
    {file = "pydantic_core-2.23.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:81965a16b675b35e1d09dd14df53f190f9129c0202356ed44ab2728b1c905658"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffa2ebd4c8530079140dd2d7f794a9d9a73cbb8e9d59ffe24c63436efa8f271"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:61817945f2fe7d166e75fbfb28004034b48e44878177fc54d81688e7b85a3665"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29d2c342c4bc01b88402d60189f3df065fb0dda3654744d5a165a5288a657368"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5e11661ce0fd30a6790e8bcdf263b9ec5988e95e63cf901972107efc49218b13"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d18368b137c6295db49ce7218b1a9ba15c5bc254c96d7c9f9e924a9bc7825ad"},
    {file = "pydantic_core-2.23.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ec4e55f79b1c4ffb2eecd8a0cfba9955a2588497d96851f4c8f99aa4a1d39b12"},
    {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:374a5e5049eda9e0a44c696c7ade3ff355f06b1fe0bb945ea3cac2bc336478a2"},
    {file = "pydantic_core-2.23.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5c364564d17da23db1106787675fc7af45f2f7b58b4173bfdd105564e132e6fb"},
    {file = "pydantic_core-2.23.4-cp38-none-win32.whl", hash = "sha256:d7a80d21d613eec45e3d41eb22f8f94ddc758a6c4720842dc74c0581f54993d6"},
    {file = "pydantic_core-2.23.4-cp38-none-win_amd64.whl", hash = "sha256:5f5ff8d839f4566a474a969508fe1c5e59c31c80d9e140566f9a37bba7b8d556"},
    {file = "pydantic_core-2.23.4-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a4fa4fc04dff799089689f4fd502ce7d59de529fc2f40a2c8836886c03e0175a"},
    {file = "pydantic_core-2.23.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0a7df63886be5e270da67e0966cf4afbae86069501d35c8c1b3b6c168f42cb36"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcedcd19a557e182628afa1d553c3895a9f825b936415d0dbd3cd0bbcfd29b4b"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5f54b118ce5de9ac21c363d9b3caa6c800341e8c47a508787e5868c6b79c9323"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:86d2f57d3e1379a9525c5ab067b27dbb8a0642fb5d454e17a9ac434f9ce523e3"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:de6d1d1b9e5101508cb37ab0d972357cac5235f5c6533d1071964c47139257df"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1278e0d324f6908e872730c9102b0112477a7f7cf88b308e4fc36ce1bdb6d58c"},
    {file = "pydantic_core-2.23.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9a6b5099eeec78827553827f4c6b8615978bb4b6a88e5d9b93eddf8bb6790f55"},
    {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:e55541f756f9b3ee346b840103f32779c695a19826a4c442b7954550a0972040"},
    {file = "pydantic_core-2.23.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a5c7ba8ffb6d6f8f2ab08743be203654bb1aaa8c9dcb09f82ddd34eadb695605"},
    {file = "pydantic_core-2.23.4-cp39-none-win32.whl", hash = "sha256:37b0fe330e4a58d3c58b24d91d1eb102aeec675a3db4c292ec3928ecd892a9a6"},
    {file = "pydantic_core-2.23.4-cp39-none-win_amd64.whl", hash = "sha256:1498bec4c05c9c787bde9125cfdcc63a41004ff167f495063191b863399b1a29"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:f455ee30a9d61d3e1a15abd5068827773d6e4dc513e795f380cdd59932c782d5"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1e90d2e3bd2c3863d48525d297cd143fe541be8bbf6f579504b9712cb6b643ec"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2e203fdf807ac7e12ab59ca2bfcabb38c7cf0b33c41efeb00f8e5da1d86af480"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e08277a400de01bc72436a0ccd02bdf596631411f592ad985dcee21445bd0068"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f220b0eea5965dec25480b6333c788fb72ce5f9129e8759ef876a1d805d00801"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:d06b0c8da4f16d1d1e352134427cb194a0a6e19ad5db9161bf32b2113409e728"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:ba1a0996f6c2773bd83e63f18914c1de3c9dd26d55f4ac302a7efe93fb8e7433"},
    {file = "pydantic_core-2.23.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:9a5bce9d23aac8f0cf0836ecfc033896aa8443b501c58d0602dbfd5bd5b37753"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:78ddaaa81421a29574a682b3179d4cf9e6d405a09b99d93ddcf7e5239c742e21"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:883a91b5dd7d26492ff2f04f40fbb652de40fcc0afe07e8129e8ae779c2110eb"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:88ad334a15b32a791ea935af224b9de1bf99bcd62fabf745d5f3442199d86d59"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:233710f069d251feb12a56da21e14cca67994eab08362207785cf8c598e74577"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:19442362866a753485ba5e4be408964644dd6a09123d9416c54cd49171f50744"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:624e278a7d29b6445e4e813af92af37820fafb6dcc55c012c834f9e26f9aaaef"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f5ef8f42bec47f21d07668a043f077d507e5bf4e668d5c6dfe6aaba89de1a5b8"},
    {file = "pydantic_core-2.23.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:aea443fffa9fbe3af1a9ba721a87f926fe548d32cab71d188a6ede77d0ff244e"},
    {file = "pydantic_core-2.23.4.tar.gz", hash = "sha256:2584f7cf844ac4d970fba483a717dbe10c1c1c96a969bf65d61ffe94df1b2863"},
]

[package.dependencies]
|
||||
@@ -1223,13 +1215,13 @@ typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
|
||||
|
||||
[[package]]
name = "pydantic-settings"
version = "2.7.0"
version = "2.6.1"
description = "Settings management using Pydantic"
optional = false
python-versions = ">=3.8"
files = [
{file = "pydantic_settings-2.7.0-py3-none-any.whl", hash = "sha256:e00c05d5fa6cbbb227c84bd7487c5c1065084119b750df7c8c1a554aed236eb5"},
{file = "pydantic_settings-2.7.0.tar.gz", hash = "sha256:ac4bfd4a36831a48dbf8b2d9325425b549a0a6f18cea118436d728eb4f1c4d66"},
{file = "pydantic_settings-2.6.1-py3-none-any.whl", hash = "sha256:7fb0637c786a558d3103436278a7c4f1cfd29ba8973238a50c5bb9a55387da87"},
{file = "pydantic_settings-2.6.1.tar.gz", hash = "sha256:e0f92546d8a9923cb8941689abf85d6601a8c19a23e97a34b2964a2e3f813ca0"},
]

[package.dependencies]
@@ -1243,13 +1235,13 @@ yaml = ["pyyaml (>=6.0.1)"]

[[package]]
name = "pyjwt"
version = "2.10.1"
version = "2.10.0"
description = "JSON Web Token implementation in Python"
optional = false
python-versions = ">=3.9"
files = [
{file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"},
{file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"},
{file = "PyJWT-2.10.0-py3-none-any.whl", hash = "sha256:543b77207db656de204372350926bed5a86201c4cbff159f623f79c7bb487a15"},
{file = "pyjwt-2.10.0.tar.gz", hash = "sha256:7628a7eb7938959ac1b26e819a1df0fd3259505627b575e4bad6d08f76db695c"},
]

[package.extras]
@@ -1282,20 +1274,20 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments

[[package]]
name = "pytest-asyncio"
version = "0.25.0"
version = "0.24.0"
description = "Pytest support for asyncio"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
files = [
{file = "pytest_asyncio-0.25.0-py3-none-any.whl", hash = "sha256:db5432d18eac6b7e28b46dcd9b69921b55c3b1086e85febfe04e70b18d9e81b3"},
{file = "pytest_asyncio-0.25.0.tar.gz", hash = "sha256:8c0610303c9e0442a5db8604505fc0f545456ba1528824842b37b4a626cbf609"},
{file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
{file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
]

[package.dependencies]
pytest = ">=8.2,<9"

[package.extras]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"]
docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]

[[package]]
@@ -1415,29 +1407,29 @@ pyasn1 = ">=0.1.3"

[[package]]
name = "ruff"
version = "0.8.6"
version = "0.8.2"
description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false
python-versions = ">=3.7"
files = [
{file = "ruff-0.8.6-py3-none-linux_armv6l.whl", hash = "sha256:defed167955d42c68b407e8f2e6f56ba52520e790aba4ca707a9c88619e580e3"},
{file = "ruff-0.8.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:54799ca3d67ae5e0b7a7ac234baa657a9c1784b48ec954a094da7c206e0365b1"},
{file = "ruff-0.8.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e88b8f6d901477c41559ba540beeb5a671e14cd29ebd5683903572f4b40a9807"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0509e8da430228236a18a677fcdb0c1f102dd26d5520f71f79b094963322ed25"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:91a7ddb221779871cf226100e677b5ea38c2d54e9e2c8ed847450ebbdf99b32d"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:248b1fb3f739d01d528cc50b35ee9c4812aa58cc5935998e776bf8ed5b251e75"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:bc3c083c50390cf69e7e1b5a5a7303898966be973664ec0c4a4acea82c1d4315"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52d587092ab8df308635762386f45f4638badb0866355b2b86760f6d3c076188"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61323159cf21bc3897674e5adb27cd9e7700bab6b84de40d7be28c3d46dc67cf"},
{file = "ruff-0.8.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ae4478b1471fc0c44ed52a6fb787e641a2ac58b1c1f91763bafbc2faddc5117"},
{file = "ruff-0.8.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:0c000a471d519b3e6cfc9c6680025d923b4ca140ce3e4612d1a2ef58e11f11fe"},
{file = "ruff-0.8.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:9257aa841e9e8d9b727423086f0fa9a86b6b420fbf4bf9e1465d1250ce8e4d8d"},
{file = "ruff-0.8.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:45a56f61b24682f6f6709636949ae8cc82ae229d8d773b4c76c09ec83964a95a"},
{file = "ruff-0.8.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:496dd38a53aa173481a7d8866bcd6451bd934d06976a2505028a50583e001b76"},
{file = "ruff-0.8.6-py3-none-win32.whl", hash = "sha256:e169ea1b9eae61c99b257dc83b9ee6c76f89042752cb2d83486a7d6e48e8f764"},
{file = "ruff-0.8.6-py3-none-win_amd64.whl", hash = "sha256:f1d70bef3d16fdc897ee290d7d20da3cbe4e26349f62e8a0274e7a3f4ce7a905"},
{file = "ruff-0.8.6-py3-none-win_arm64.whl", hash = "sha256:7d7fc2377a04b6e04ffe588caad613d0c460eb2ecba4c0ccbbfe2bc973cbc162"},
{file = "ruff-0.8.6.tar.gz", hash = "sha256:dcad24b81b62650b0eb8814f576fc65cfee8674772a6e24c9b747911801eeaa5"},
{file = "ruff-0.8.2-py3-none-linux_armv6l.whl", hash = "sha256:c49ab4da37e7c457105aadfd2725e24305ff9bc908487a9bf8d548c6dad8bb3d"},
{file = "ruff-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ec016beb69ac16be416c435828be702ee694c0d722505f9c1f35e1b9c0cc1bf5"},
{file = "ruff-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f05cdf8d050b30e2ba55c9b09330b51f9f97d36d4673213679b965d25a785f3c"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60f578c11feb1d3d257b2fb043ddb47501ab4816e7e221fbb0077f0d5d4e7b6f"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbd5cf9b0ae8f30eebc7b360171bd50f59ab29d39f06a670b3e4501a36ba5897"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b402ddee3d777683de60ff76da801fa7e5e8a71038f57ee53e903afbcefdaa58"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:705832cd7d85605cb7858d8a13d75993c8f3ef1397b0831289109e953d833d29"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32096b41aaf7a5cc095fa45b4167b890e4c8d3fd217603f3634c92a541de7248"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e769083da9439508833cfc7c23e351e1809e67f47c50248250ce1ac52c21fb93"},
{file = "ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fe716592ae8a376c2673fdfc1f5c0c193a6d0411f90a496863c99cd9e2ae25d"},
{file = "ruff-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:81c148825277e737493242b44c5388a300584d73d5774defa9245aaef55448b0"},
{file = "ruff-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"},
{file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"},
{file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"},
{file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"},
{file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"},
{file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"},
{file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"},
]

[[package]]
@@ -1852,4 +1844,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
content-hash = "bf1b0125759dadb1369fff05ffba64fea3e82b9b7a43d0068e1c80974a4ebc1c"
content-hash = "2d92dded4ebeff76a3d1dd0bb6b348e8177f053918141b4c4828442da318a907"

@@ -10,10 +10,10 @@ packages = [{ include = "autogpt_libs" }]
colorama = "^0.4.6"
expiringdict = "^1.2.2"
google-cloud-logging = "^3.11.3"
pydantic = "^2.10.3"
pydantic-settings = "^2.7.0"
pyjwt = "^2.10.1"
pytest-asyncio = "^0.25.0"
pydantic = "^2.9.2"
pydantic-settings = "^2.6.1"
pyjwt = "^2.10.0"
pytest-asyncio = "^0.24.0"
pytest-mock = "^3.14.0"
python = ">=3.10,<4.0"
python-dotenv = "^1.0.1"
@@ -21,7 +21,7 @@ supabase = "^2.10.0"

[tool.poetry.group.dev.dependencies]
redis = "^5.2.1"
ruff = "^0.8.6"
ruff = "^0.8.2"

[build-system]
requires = ["poetry-core"]

@@ -58,21 +58,6 @@ GITHUB_CLIENT_SECRET=
GOOGLE_CLIENT_ID=
GOOGLE_CLIENT_SECRET=

# Twitter (X) OAuth 2.0 with PKCE Configuration
# 1. Create a Twitter Developer Account:
#    - Visit https://developer.x.com/en and sign up
# 2. Set up your application:
#    - Navigate to Developer Portal > Projects > Create Project
#    - Add a new app to your project
# 3. Configure app settings:
#    - App Permissions: Read + Write + Direct Messages
#    - App Type: Web App, Automated App or Bot
#    - OAuth 2.0 Callback URL: http://localhost:3000/auth/integrations/oauth_callback
#    - Save your Client ID and Client Secret below
TWITTER_CLIENT_ID=
TWITTER_CLIENT_SECRET=


## ===== OPTIONAL API KEYS ===== ##

# LLM
@@ -121,18 +106,6 @@ REPLICATE_API_KEY=
# Ideogram
IDEOGRAM_API_KEY=

# Fal
FAL_API_KEY=

# Exa
EXA_API_KEY=

# E2B
E2B_API_KEY=

# Nvidia
NVIDIA_API_KEY=

# Logging Configuration
LOG_LEVEL=INFO
ENABLE_CLOUD_LOGGING=false

@@ -17,11 +17,12 @@ RUN apt-get install -y libz-dev
RUN apt-get install -y libssl-dev
RUN apt-get install -y postgresql-client

ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH


# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

@@ -31,21 +32,25 @@ RUN pip3 install poetry
COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
WORKDIR /app/autogpt_platform/backend
RUN poetry install --no-ansi --no-root
RUN poetry config virtualenvs.create false \
    && poetry install --no-interaction --no-ansi

# Generate Prisma client
COPY autogpt_platform/backend/schema.prisma ./
RUN poetry run prisma generate
RUN poetry config virtualenvs.create false \
    && poetry run prisma generate

FROM python:3.11.10-slim-bookworm AS server_dependencies

WORKDIR /app

ENV POETRY_HOME=/opt/poetry \
    POETRY_NO_INTERACTION=1 \
    POETRY_VIRTUALENVS_CREATE=false
ENV POETRY_VERSION=1.8.3
ENV POETRY_HOME=/opt/poetry
ENV POETRY_NO_INTERACTION=1
ENV POETRY_VIRTUALENVS_CREATE=false
ENV PATH=/opt/poetry/bin:$PATH


# Upgrade pip and setuptools to fix security vulnerabilities
RUN pip3 install --upgrade pip setuptools

@@ -71,7 +76,6 @@ WORKDIR /app/autogpt_platform/backend
FROM server_dependencies AS server

COPY autogpt_platform/backend /app/autogpt_platform/backend
RUN poetry install --no-ansi --only-root

ENV DATABASE_URL=""
ENV PORT=8000

@@ -76,11 +76,7 @@ class AgentExecutorBlock(Block):
            )

            if not event.node_id:
                if event.status in [
                    ExecutionStatus.COMPLETED,
                    ExecutionStatus.TERMINATED,
                    ExecutionStatus.FAILED,
                ]:
                if event.status in [ExecutionStatus.COMPLETED, ExecutionStatus.FAILED]:
                    logger.info(f"Execution {log_id} ended with status {event.status}")
                    break
            else:

@@ -241,7 +241,7 @@ class AgentOutputBlock(Block):
            advanced=True,
        )
        format: str = SchemaField(
            description="The format string to be used to format the recorded_value. Use Jinja2 syntax.",
            description="The format string to be used to format the recorded_value.",
            default="",
            advanced=True,
        )

@@ -1,59 +0,0 @@
from pydantic import BaseModel

from backend.data.block import (
    Block,
    BlockCategory,
    BlockManualWebhookConfig,
    BlockOutput,
    BlockSchema,
)
from backend.data.model import SchemaField
from backend.integrations.webhooks.compass import CompassWebhookType


class Transcription(BaseModel):
    text: str
    speaker: str
    end: float
    start: float
    duration: float


class TranscriptionDataModel(BaseModel):
    date: str
    transcription: str
    transcriptions: list[Transcription]


class CompassAITriggerBlock(Block):
    class Input(BlockSchema):
        payload: TranscriptionDataModel = SchemaField(hidden=True)

    class Output(BlockSchema):
        transcription: str = SchemaField(
            description="The contents of the compass transcription."
        )

    def __init__(self):
        super().__init__(
            id="9464a020-ed1d-49e1-990f-7f2ac924a2b7",
            description="This block will output the contents of the compass transcription.",
            categories={BlockCategory.HARDWARE},
            input_schema=CompassAITriggerBlock.Input,
            output_schema=CompassAITriggerBlock.Output,
            webhook_config=BlockManualWebhookConfig(
                provider="compass",
                webhook_type=CompassWebhookType.TRANSCRIPTION,
            ),
            test_input=[
                {"input": "Hello, World!"},
                {"input": "Hello, World!", "data": "Existing Data"},
            ],
            # test_output=[
            #     ("output", "Hello, World!"),  # No data provided, so trigger is returned
            #     ("output", "Existing Data"),  # Data is provided, so data is returned.
            # ],
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        yield "transcription", input_data.payload.transcription
@@ -1,87 +0,0 @@
from typing import List, Optional

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentRetrievalSettings(BaseModel):
    text: Optional[dict] = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
        advanced=True,
    )
    highlights: Optional[dict] = SchemaField(
        description="Highlight settings",
        default={
            "numSentences": 3,
            "highlightsPerUrl": 3,
            "query": "",
        },
        advanced=True,
    )
    summary: Optional[dict] = SchemaField(
        description="Summary settings",
        default={"query": ""},
        advanced=True,
    )


class ExaContentsBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        ids: List[str] = SchemaField(
            description="Array of document IDs obtained from searches",
        )
        contents: ContentRetrievalSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentRetrievalSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
        results: list = SchemaField(
            description="List of document contents",
            default=[],
        )

    def __init__(self):
        super().__init__(
            id="c52be83f-f8cd-4180-b243-af35f986b461",
            description="Retrieves document contents using Exa's contents API",
            categories={BlockCategory.SEARCH},
            input_schema=ExaContentsBlock.Input,
            output_schema=ExaContentsBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/contents"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload = {
            "ids": input_data.ids,
            "text": input_data.contents.text,
            "highlights": input_data.contents.highlights,
            "summary": input_data.contents.summary,
        }

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            yield "results", data.get("results", [])
        except Exception as e:
            yield "error", str(e)
            yield "results", []
@@ -1,54 +0,0 @@
from typing import Optional

from pydantic import BaseModel

from backend.data.model import SchemaField


class TextSettings(BaseModel):
    max_characters: int = SchemaField(
        default=1000,
        description="Maximum number of characters to return",
        placeholder="1000",
    )
    include_html_tags: bool = SchemaField(
        default=False,
        description="Whether to include HTML tags in the text",
        placeholder="False",
    )


class HighlightSettings(BaseModel):
    num_sentences: int = SchemaField(
        default=3,
        description="Number of sentences per highlight",
        placeholder="3",
    )
    highlights_per_url: int = SchemaField(
        default=3,
        description="Number of highlights per URL",
        placeholder="3",
    )


class SummarySettings(BaseModel):
    query: Optional[str] = SchemaField(
        default="",
        description="Query string for summarization",
        placeholder="Enter query",
    )


class ContentSettings(BaseModel):
    text: TextSettings = SchemaField(
        default=TextSettings(),
        description="Text content settings",
    )
    highlights: HighlightSettings = SchemaField(
        default=HighlightSettings(),
        description="Highlight settings",
    )
    summary: SummarySettings = SchemaField(
        default=SummarySettings(),
        description="Summary settings",
    )
@@ -1,76 +1,84 @@
from datetime import datetime
from typing import List

from pydantic import BaseModel

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
)
from backend.blocks.exa.helpers import ContentSettings
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class ContentSettings(BaseModel):
    text: dict = SchemaField(
        description="Text content settings",
        default={"maxCharacters": 1000, "includeHtmlTags": False},
    )
    highlights: dict = SchemaField(
        description="Highlight settings",
        default={"numSentences": 3, "highlightsPerUrl": 3},
    )
    summary: dict = SchemaField(
        description="Summary settings",
        default={"query": ""},
    )


class ExaSearchBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        query: str = SchemaField(description="The search query")
        use_auto_prompt: bool = SchemaField(
        useAutoprompt: bool = SchemaField(
            description="Whether to use autoprompt",
            default=True,
            advanced=True,
        )
        type: str = SchemaField(
            description="Type of search",
            default="",
            advanced=True,
        )
        category: str = SchemaField(
            description="Category to search within",
            default="",
            advanced=True,
        )
        number_of_results: int = SchemaField(
        numResults: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
        )
        include_domains: List[str] = SchemaField(
        includeDomains: List[str] = SchemaField(
            description="Domains to include in search",
            default=[],
        )
        exclude_domains: List[str] = SchemaField(
        excludeDomains: List[str] = SchemaField(
            description="Domains to exclude from search",
            default=[],
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
        startCrawlDate: datetime = SchemaField(
            description="Start date for crawled content",
        )
        end_crawl_date: datetime = SchemaField(
        endCrawlDate: datetime = SchemaField(
            description="End date for crawled content",
        )
        start_published_date: datetime = SchemaField(
        startPublishedDate: datetime = SchemaField(
            description="Start date for published content",
        )
        end_published_date: datetime = SchemaField(
        endPublishedDate: datetime = SchemaField(
            description="End date for published content",
        )
        include_text: List[str] = SchemaField(
        includeText: List[str] = SchemaField(
            description="Text patterns to include",
            default=[],
            advanced=True,
        )
        exclude_text: List[str] = SchemaField(
        excludeText: List[str] = SchemaField(
            description="Text patterns to exclude",
            default=[],
            advanced=True,
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
@@ -99,38 +107,44 @@ class ExaSearchBlock(Block):
        payload = {
            "query": input_data.query,
            "useAutoprompt": input_data.use_auto_prompt,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
        }

        date_field_mapping = {
            "start_crawl_date": "startCrawlDate",
            "end_crawl_date": "endCrawlDate",
            "start_published_date": "startPublishedDate",
            "end_published_date": "endPublishedDate",
            "useAutoprompt": input_data.useAutoprompt,
            "numResults": input_data.numResults,
            "contents": {
                "text": {"maxCharacters": 1000, "includeHtmlTags": False},
                "highlights": {
                    "numSentences": 3,
                    "highlightsPerUrl": 3,
                },
                "summary": {"query": ""},
            },
        }

        # Add dates if they exist
        for input_field, api_field in date_field_mapping.items():
            value = getattr(input_data, input_field, None)
        date_fields = [
            "startCrawlDate",
            "endCrawlDate",
            "startPublishedDate",
            "endPublishedDate",
        ]
        for field in date_fields:
            value = getattr(input_data, field, None)
            if value:
                payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

        optional_field_mapping = {
            "type": "type",
            "category": "category",
            "include_domains": "includeDomains",
            "exclude_domains": "excludeDomains",
            "include_text": "includeText",
            "exclude_text": "excludeText",
        }
                payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

        # Add other fields
        for input_field, api_field in optional_field_mapping.items():
            value = getattr(input_data, input_field)
        optional_fields = [
            "type",
            "category",
            "includeDomains",
            "excludeDomains",
            "includeText",
            "excludeText",
        ]

        for field in optional_fields:
            value = getattr(input_data, field)
            if value:  # Only add non-empty values
                payload[api_field] = value
                payload[field] = value

        try:
            response = requests.post(url, headers=headers, json=payload)

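Both versions of the hunk above serialize optional datetime inputs with strftime into the ISO-8601 UTC timestamps the Exa API expects, adding them to the payload only when set. A minimal standalone sketch of the pattern (the helper name is illustrative; field names are taken from the block above):

    from datetime import datetime

    # Illustrative sketch: optional datetimes are added to the request payload
    # only when present, formatted as ISO-8601 UTC strings for the Exa API.
    def add_date_fields(payload: dict, inputs: dict) -> dict:
        for field in ("startCrawlDate", "endCrawlDate", "startPublishedDate", "endPublishedDate"):
            value = inputs.get(field)
            if isinstance(value, datetime):
                payload[field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")
        return payload

    payload = add_date_fields({"query": "ai agents"}, {"startCrawlDate": datetime(2024, 1, 1)})
    # -> {'query': 'ai agents', 'startCrawlDate': '2024-01-01T00:00:00.000Z'}
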
@@ -1,128 +0,0 @@
from datetime import datetime
from typing import Any, List

from backend.blocks.exa._auth import (
    ExaCredentials,
    ExaCredentialsField,
    ExaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests

from .helpers import ContentSettings


class ExaFindSimilarBlock(Block):
    class Input(BlockSchema):
        credentials: ExaCredentialsInput = ExaCredentialsField()
        url: str = SchemaField(
            description="The url for which you would like to find similar links"
        )
        number_of_results: int = SchemaField(
            description="Number of results to return",
            default=10,
            advanced=True,
        )
        include_domains: List[str] = SchemaField(
            description="Domains to include in search",
            default=[],
            advanced=True,
        )
        exclude_domains: List[str] = SchemaField(
            description="Domains to exclude from search",
            default=[],
            advanced=True,
        )
        start_crawl_date: datetime = SchemaField(
            description="Start date for crawled content",
        )
        end_crawl_date: datetime = SchemaField(
            description="End date for crawled content",
        )
        start_published_date: datetime = SchemaField(
            description="Start date for published content",
        )
        end_published_date: datetime = SchemaField(
            description="End date for published content",
        )
        include_text: List[str] = SchemaField(
            description="Text patterns to include (max 1 string, up to 5 words)",
            default=[],
            advanced=True,
        )
        exclude_text: List[str] = SchemaField(
            description="Text patterns to exclude (max 1 string, up to 5 words)",
            default=[],
            advanced=True,
        )
        contents: ContentSettings = SchemaField(
            description="Content retrieval settings",
            default=ContentSettings(),
            advanced=True,
        )

    class Output(BlockSchema):
        results: List[Any] = SchemaField(
            description="List of similar documents with title, URL, published date, author, and score",
            default=[],
        )

    def __init__(self):
        super().__init__(
            id="5e7315d1-af61-4a0c-9350-7c868fa7438a",
            description="Finds similar links using Exa's findSimilar API",
            categories={BlockCategory.SEARCH},
            input_schema=ExaFindSimilarBlock.Input,
            output_schema=ExaFindSimilarBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: ExaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://api.exa.ai/findSimilar"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": credentials.api_key.get_secret_value(),
        }

        payload = {
            "url": input_data.url,
            "numResults": input_data.number_of_results,
            "contents": input_data.contents.dict(),
        }

        optional_field_mapping = {
            "include_domains": "includeDomains",
            "exclude_domains": "excludeDomains",
            "include_text": "includeText",
            "exclude_text": "excludeText",
        }

        # Add optional fields if they have values
        for input_field, api_field in optional_field_mapping.items():
            value = getattr(input_data, input_field)
            if value:  # Only add non-empty values
                payload[api_field] = value

        date_field_mapping = {
            "start_crawl_date": "startCrawlDate",
            "end_crawl_date": "endCrawlDate",
            "start_published_date": "startPublishedDate",
            "end_published_date": "endPublishedDate",
        }

        # Add dates if they exist
        for input_field, api_field in date_field_mapping.items():
            value = getattr(input_data, input_field, None)
            if value:
                payload[api_field] = value.strftime("%Y-%m-%dT%H:%M:%S.000Z")

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            yield "results", data.get("results", [])
        except Exception as e:
            yield "error", str(e)
            yield "results", []
@@ -699,420 +699,3 @@ class GithubDeleteBranchBlock(Block):
                input_data.branch,
            )
            yield "status", status


class GithubCreateFileBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo_url: str = SchemaField(
            description="URL of the GitHub repository",
            placeholder="https://github.com/owner/repo",
        )
        file_path: str = SchemaField(
            description="Path where the file should be created",
            placeholder="path/to/file.txt",
        )
        content: str = SchemaField(
            description="Content to write to the file",
            placeholder="File content here",
        )
        branch: str = SchemaField(
            description="Branch where the file should be created",
            default="main",
        )
        commit_message: str = SchemaField(
            description="Message for the commit",
            default="Create new file",
        )

    class Output(BlockSchema):
        url: str = SchemaField(description="URL of the created file")
        sha: str = SchemaField(description="SHA of the commit")
        error: str = SchemaField(
            description="Error message if the file creation failed"
        )

    def __init__(self):
        super().__init__(
            id="8fd132ac-b917-428a-8159-d62893e8a3fe",
            description="This block creates a new file in a GitHub repository.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubCreateFileBlock.Input,
            output_schema=GithubCreateFileBlock.Output,
            test_input={
                "repo_url": "https://github.com/owner/repo",
                "file_path": "test/file.txt",
                "content": "Test content",
                "branch": "main",
                "commit_message": "Create test file",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
                ("sha", "abc123"),
            ],
            test_mock={
                "create_file": lambda *args, **kwargs: (
                    "https://github.com/owner/repo/blob/main/test/file.txt",
                    "abc123",
                )
            },
        )

    @staticmethod
    def create_file(
        credentials: GithubCredentials,
        repo_url: str,
        file_path: str,
        content: str,
        branch: str,
        commit_message: str,
    ) -> tuple[str, str]:
        api = get_api(credentials)
        # Convert content to base64
        content_bytes = content.encode("utf-8")
        content_base64 = base64.b64encode(content_bytes).decode("utf-8")

        # Create the file using the GitHub API
        contents_url = f"{repo_url}/contents/{file_path}"
        data = {
            "message": commit_message,
            "content": content_base64,
            "branch": branch,
        }
        response = api.put(contents_url, json=data)
        result = response.json()

        return result["content"]["html_url"], result["commit"]["sha"]

    def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            url, sha = self.create_file(
                credentials,
                input_data.repo_url,
                input_data.file_path,
                input_data.content,
                input_data.branch,
                input_data.commit_message,
            )
            yield "url", url
            yield "sha", sha
        except Exception as e:
            yield "error", str(e)

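The create_file helper above base64-encodes the file body before PUTting it to the repository contents endpoint, since the GitHub contents API accepts file content only as base64 text. A sketch of just the encoding step:

    import base64

    # The GitHub contents API requires the file body as base64-encoded text.
    def encode_content(content: str) -> str:
        return base64.b64encode(content.encode("utf-8")).decode("utf-8")

    data = {
        "message": "Create new file",
        "content": encode_content("Test content"),  # -> "VGVzdCBjb250ZW50"
        "branch": "main",
    }
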
class GithubUpdateFileBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo_url: str = SchemaField(
            description="URL of the GitHub repository",
            placeholder="https://github.com/owner/repo",
        )
        file_path: str = SchemaField(
            description="Path to the file to update",
            placeholder="path/to/file.txt",
        )
        content: str = SchemaField(
            description="New content for the file",
            placeholder="Updated content here",
        )
        branch: str = SchemaField(
            description="Branch containing the file",
            default="main",
        )
        commit_message: str = SchemaField(
            description="Message for the commit",
            default="Update file",
        )

    class Output(BlockSchema):
        url: str = SchemaField(description="URL of the updated file")
        sha: str = SchemaField(description="SHA of the commit")
        error: str = SchemaField(description="Error message if the file update failed")

    def __init__(self):
        super().__init__(
            id="30be12a4-57cb-4aa4-baf5-fcc68d136076",
            description="This block updates an existing file in a GitHub repository.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubUpdateFileBlock.Input,
            output_schema=GithubUpdateFileBlock.Output,
            test_input={
                "repo_url": "https://github.com/owner/repo",
                "file_path": "test/file.txt",
                "content": "Updated content",
                "branch": "main",
                "commit_message": "Update test file",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("url", "https://github.com/owner/repo/blob/main/test/file.txt"),
                ("sha", "def456"),
            ],
            test_mock={
                "update_file": lambda *args, **kwargs: (
                    "https://github.com/owner/repo/blob/main/test/file.txt",
                    "def456",
                )
            },
        )

    @staticmethod
    def update_file(
        credentials: GithubCredentials,
        repo_url: str,
        file_path: str,
        content: str,
        branch: str,
        commit_message: str,
    ) -> tuple[str, str]:
        api = get_api(credentials)

        # First get the current file to get its SHA
        contents_url = f"{repo_url}/contents/{file_path}"
        params = {"ref": branch}
        response = api.get(contents_url, params=params)
        current_file = response.json()

        # Convert new content to base64
        content_bytes = content.encode("utf-8")
        content_base64 = base64.b64encode(content_bytes).decode("utf-8")

        # Update the file
        data = {
            "message": commit_message,
            "content": content_base64,
            "sha": current_file["sha"],
            "branch": branch,
        }
        response = api.put(contents_url, json=data)
        result = response.json()

        return result["content"]["html_url"], result["commit"]["sha"]

    def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            url, sha = self.update_file(
                credentials,
                input_data.repo_url,
                input_data.file_path,
                input_data.content,
                input_data.branch,
                input_data.commit_message,
            )
            yield "url", url
            yield "sha", sha
        except Exception as e:
            yield "error", str(e)

class GithubCreateRepositoryBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        name: str = SchemaField(
            description="Name of the repository to create",
            placeholder="my-new-repo",
        )
        description: str = SchemaField(
            description="Description of the repository",
            placeholder="A description of the repository",
            default="",
        )
        private: bool = SchemaField(
            description="Whether the repository should be private",
            default=False,
        )
        auto_init: bool = SchemaField(
            description="Whether to initialize the repository with a README",
            default=True,
        )
        gitignore_template: str = SchemaField(
            description="Git ignore template to use (e.g., Python, Node, Java)",
            default="",
        )

    class Output(BlockSchema):
        url: str = SchemaField(description="URL of the created repository")
        clone_url: str = SchemaField(description="Git clone URL of the repository")
        error: str = SchemaField(
            description="Error message if the repository creation failed"
        )

    def __init__(self):
        super().__init__(
            id="029ec3b8-1cfd-46d3-b6aa-28e4a706efd1",
            description="This block creates a new GitHub repository.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubCreateRepositoryBlock.Input,
            output_schema=GithubCreateRepositoryBlock.Output,
            test_input={
                "name": "test-repo",
                "description": "A test repository",
                "private": False,
                "auto_init": True,
                "gitignore_template": "Python",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("url", "https://github.com/owner/test-repo"),
                ("clone_url", "https://github.com/owner/test-repo.git"),
            ],
            test_mock={
                "create_repository": lambda *args, **kwargs: (
                    "https://github.com/owner/test-repo",
                    "https://github.com/owner/test-repo.git",
                )
            },
        )

    @staticmethod
    def create_repository(
        credentials: GithubCredentials,
        name: str,
        description: str,
        private: bool,
        auto_init: bool,
        gitignore_template: str,
    ) -> tuple[str, str]:
        api = get_api(credentials, convert_urls=False)  # Disable URL conversion
        data = {
            "name": name,
            "description": description,
            "private": private,
            "auto_init": auto_init,
        }

        if gitignore_template:
            data["gitignore_template"] = gitignore_template

        # Create repository using the user endpoint
        response = api.post("https://api.github.com/user/repos", json=data)
        result = response.json()

        return result["html_url"], result["clone_url"]

    def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            url, clone_url = self.create_repository(
                credentials,
                input_data.name,
                input_data.description,
                input_data.private,
                input_data.auto_init,
                input_data.gitignore_template,
            )
            yield "url", url
            yield "clone_url", clone_url
        except Exception as e:
            yield "error", str(e)

class GithubListStargazersBlock(Block):
    class Input(BlockSchema):
        credentials: GithubCredentialsInput = GithubCredentialsField("repo")
        repo_url: str = SchemaField(
            description="URL of the GitHub repository",
            placeholder="https://github.com/owner/repo",
        )

    class Output(BlockSchema):
        class StargazerItem(TypedDict):
            username: str
            url: str

        stargazer: StargazerItem = SchemaField(
            title="Stargazer",
            description="Stargazers with their username and profile URL",
        )
        error: str = SchemaField(
            description="Error message if listing stargazers failed"
        )

    def __init__(self):
        super().__init__(
            id="a4b9c2d1-e5f6-4g7h-8i9j-0k1l2m3n4o5p",  # Generated unique UUID
            description="This block lists all users who have starred a specified GitHub repository.",
            categories={BlockCategory.DEVELOPER_TOOLS},
            input_schema=GithubListStargazersBlock.Input,
            output_schema=GithubListStargazersBlock.Output,
            test_input={
                "repo_url": "https://github.com/owner/repo",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                (
                    "stargazer",
                    {
                        "username": "octocat",
                        "url": "https://github.com/octocat",
                    },
                )
            ],
            test_mock={
                "list_stargazers": lambda *args, **kwargs: [
                    {
                        "username": "octocat",
                        "url": "https://github.com/octocat",
                    }
                ]
            },
        )

    @staticmethod
    def list_stargazers(
        credentials: GithubCredentials, repo_url: str
    ) -> list[Output.StargazerItem]:
        api = get_api(credentials)
        # Add /stargazers to the repo URL to get stargazers endpoint
        stargazers_url = f"{repo_url}/stargazers"
        # Set accept header to get starred_at timestamp
        headers = {"Accept": "application/vnd.github.star+json"}
        response = api.get(stargazers_url, headers=headers)
        data = response.json()

        stargazers: list[GithubListStargazersBlock.Output.StargazerItem] = [
            {
                "username": stargazer["login"],
                "url": stargazer["html_url"],
            }
            for stargazer in data
        ]
        return stargazers

    def run(
        self,
        input_data: Input,
        *,
        credentials: GithubCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            stargazers = self.list_stargazers(
                credentials,
                input_data.repo_url,
            )
            yield from (("stargazer", stargazer) for stargazer in stargazers)
        except Exception as e:
            yield "error", str(e)

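list_stargazers above sends Accept: application/vnd.github.star+json so GitHub includes the starred_at timestamp. Note that with this media type each array element nests the account under a "user" key; a raw-requests sketch for comparison (illustrative only, since the project routes calls through its own authenticated API wrapper):

    import requests

    resp = requests.get(
        "https://api.github.com/repos/owner/repo/stargazers",
        headers={"Accept": "application/vnd.github.star+json"},  # adds starred_at
    )
    resp.raise_for_status()
    # With star+json, each item looks like {"starred_at": ..., "user": {...}}.
    stargazers = [
        {"username": item["user"]["login"], "url": item["user"]["html_url"]}
        for item in resp.json()
    ]
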
@@ -56,24 +56,15 @@ class SendWebRequestBlock(Block):
        )

    def run(self, input_data: Input, **kwargs) -> BlockOutput:
        body = input_data.body

        if input_data.json_format:
            if isinstance(body, str):
                try:
                    # Try to parse as JSON first
                    body = json.loads(body)
                except json.JSONDecodeError:
                    # If it's not valid JSON and just plain text,
                    # we should send it as plain text instead
                    input_data.json_format = False
        if isinstance(input_data.body, str):
            input_data.body = json.loads(input_data.body)

        response = requests.request(
            input_data.method.value,
            input_data.url,
            headers=input_data.headers,
            json=body if input_data.json_format else None,
            data=body if not input_data.json_format else None,
            json=input_data.body if input_data.json_format else None,
            data=input_data.body if not input_data.json_format else None,
        )
        result = response.json() if input_data.json_format else response.text

|
)
from backend.util import json
from backend.util.settings import BehaveAs, Settings
from backend.util.text import TextFormatter

logger = logging.getLogger(__name__)
fmt = TextFormatter()

LLMProviderName = Literal[
    ProviderName.ANTHROPIC,
@@ -111,7 +109,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
    LLAMA3_1_70B = "llama-3.1-70b-versatile"
    LLAMA3_1_8B = "llama-3.1-8b-instant"
    # Ollama models
    OLLAMA_LLAMA3_2 = "llama3.2"
    OLLAMA_LLAMA3_8B = "llama3"
    OLLAMA_LLAMA3_405B = "llama3.1:405b"
    OLLAMA_DOLPHIN = "dolphin-mistral:latest"
@@ -166,7 +163,6 @@ MODEL_METADATA = {
    # Limited to 16k during preview
    LlmModel.LLAMA3_1_70B: ModelMetadata("groq", 131072),
    LlmModel.LLAMA3_1_8B: ModelMetadata("groq", 131072),
    LlmModel.OLLAMA_LLAMA3_2: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_LLAMA3_8B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_LLAMA3_405B: ModelMetadata("ollama", 8192),
    LlmModel.OLLAMA_DOLPHIN: ModelMetadata("ollama", 32768),
@@ -238,9 +234,7 @@ class AIStructuredResponseGeneratorBlock(Block):
        description="Number of times to retry the LLM call if the response does not match the expected format.",
    )
    prompt_values: dict[str, str] = SchemaField(
        advanced=False,
        default={},
        description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
        advanced=False, default={}, description="Values used to fill in the prompt."
    )
    max_tokens: int | None = SchemaField(
        advanced=True,
@@ -454,8 +448,8 @@ class AIStructuredResponseGeneratorBlock(Block):

        values = input_data.prompt_values
        if values:
            input_data.prompt = fmt.format_string(input_data.prompt, values)
            input_data.sys_prompt = fmt.format_string(input_data.sys_prompt, values)
            input_data.prompt = input_data.prompt.format(**values)
            input_data.sys_prompt = input_data.sys_prompt.format(**values)

        if input_data.sys_prompt:
            prompt.append({"role": "system", "content": input_data.sys_prompt})
@@ -582,9 +576,7 @@ class AITextGeneratorBlock(Block):
        description="Number of times to retry the LLM call if the response does not match the expected format.",
    )
    prompt_values: dict[str, str] = SchemaField(
        advanced=False,
        default={},
        description="Values used to fill in the prompt. The values can be used in the prompt by putting them in a double curly braces, e.g. {{variable_name}}.",
        advanced=False, default={}, description="Values used to fill in the prompt."
    )
    ollama_host: str = SchemaField(
        advanced=True,

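These hunks swap between TextFormatter.format_string and plain str.format for filling prompt_values; the double-curly {{variable_name}} style described in the field help corresponds to Jinja2-like templating. A standalone approximation using the jinja2 package directly (an assumption here — the project's TextFormatter is its own wrapper):

    from jinja2 import Environment

    def format_string(template: str, values: dict) -> str:
        # Render {{variable_name}} placeholders from a values dict.
        return Environment().from_string(template).render(**values)

    print(format_string("Summarize {{topic}} in {{count}} bullets.",
                        {"topic": "the release notes", "count": 3}))
    # -> Summarize the release notes in 3 bullets.
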
@@ -1,32 +0,0 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsField, CredentialsMetaInput
from backend.integrations.providers import ProviderName

NvidiaCredentials = APIKeyCredentials
NvidiaCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.NVIDIA],
    Literal["api_key"],
]

TEST_CREDENTIALS = APIKeyCredentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="nvidia",
    api_key=SecretStr("mock-nvidia-api-key"),
    title="Mock Nvidia API key",
    expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}


def NvidiaCredentialsField() -> NvidiaCredentialsInput:
    """Creates an Nvidia credentials input on a block."""
    return CredentialsField(description="The Nvidia integration requires an API Key.")
@@ -1,90 +0,0 @@
from backend.blocks.nvidia._auth import (
    NvidiaCredentials,
    NvidiaCredentialsField,
    NvidiaCredentialsInput,
)
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField
from backend.util.request import requests


class NvidiaDeepfakeDetectBlock(Block):
    class Input(BlockSchema):
        credentials: NvidiaCredentialsInput = NvidiaCredentialsField()
        image_base64: str = SchemaField(
            description="Image to analyze for deepfakes", image_upload=True
        )
        return_image: bool = SchemaField(
            description="Whether to return the processed image with markings",
            default=False,
        )

    class Output(BlockSchema):
        status: str = SchemaField(
            description="Detection status (SUCCESS, ERROR, CONTENT_FILTERED)",
            default="",
        )
        image: str = SchemaField(
            description="Processed image with detection markings (if return_image=True)",
            default="",
            image_output=True,
        )
        is_deepfake: float = SchemaField(
            description="Probability that the image is a deepfake (0-1)",
            default=0.0,
        )

    def __init__(self):
        super().__init__(
            id="8c7d0d67-e79c-44f6-92a1-c2600c8aac7f",
            description="Detects potential deepfakes in images using Nvidia's AI API",
            categories={BlockCategory.SAFETY},
            input_schema=NvidiaDeepfakeDetectBlock.Input,
            output_schema=NvidiaDeepfakeDetectBlock.Output,
        )

    def run(
        self, input_data: Input, *, credentials: NvidiaCredentials, **kwargs
    ) -> BlockOutput:
        url = "https://ai.api.nvidia.com/v1/cv/hive/deepfake-image-detection"

        headers = {
            "accept": "application/json",
            "content-type": "application/json",
            "Authorization": f"Bearer {credentials.api_key.get_secret_value()}",
        }

        image_data = f"data:image/jpeg;base64,{input_data.image_base64}"

        payload = {
            "input": [image_data],
            "return_image": input_data.return_image,
        }

        try:
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()

            result = data.get("data", [{}])[0]

            # Get deepfake probability from first bounding box if any
            deepfake_prob = 0.0
            if result.get("bounding_boxes"):
                deepfake_prob = result["bounding_boxes"][0].get("is_deepfake", 0.0)

            yield "status", result.get("status", "ERROR")
            yield "is_deepfake", deepfake_prob

            if input_data.return_image:
                image_data = result.get("image", "")
                output_data = f"data:image/jpeg;base64,{image_data}"
                yield "image", output_data
            else:
                yield "image", ""

        except Exception as e:
            yield "error", str(e)
            yield "status", "ERROR"
            yield "is_deepfake", 0.0
            yield "image", ""
@@ -141,10 +141,10 @@ class ExtractTextInformationBlock(Block):
class FillTextTemplateBlock(Block):
    class Input(BlockSchema):
        values: dict[str, Any] = SchemaField(
            description="Values (dict) to be used in format. These values can be used by putting them in double curly braces in the format template. e.g. {{value_name}}.",
            description="Values (dict) to be used in format"
        )
        format: str = SchemaField(
            description="Template to format the text using `values`. Use Jinja2 syntax."
            description="Template to format the text using `values`"
        )

    class Output(BlockSchema):
@@ -160,7 +160,7 @@ class FillTextTemplateBlock(Block):
        test_input=[
            {
                "values": {"name": "Alice", "hello": "Hello", "world": "World!"},
                "format": "{{hello}}, {{ world }} {{name}}",
                "format": "{hello}, {world} {{name}}",
            },
            {
                "values": {"list": ["Hello", " World!"]},

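For the first test case above, rendering "{{hello}}, {{ world }} {{name}}" with the given values yields "Hello, World! Alice"; approximated with jinja2 (the block itself renders through the project's own formatter):

    from jinja2 import Template

    values = {"name": "Alice", "hello": "Hello", "world": "World!"}
    print(Template("{{hello}}, {{ world }} {{name}}").render(**values))
    # -> Hello, World! Alice
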
@@ -1,60 +0,0 @@
from typing import Literal

from pydantic import SecretStr

from backend.data.model import (
    CredentialsField,
    CredentialsMetaInput,
    OAuth2Credentials,
    ProviderName,
)
from backend.integrations.oauth.twitter import TwitterOAuthHandler
from backend.util.settings import Secrets

# --8<-- [start:TwitterOAuthIsConfigured]
secrets = Secrets()
TWITTER_OAUTH_IS_CONFIGURED = bool(
    secrets.twitter_client_id and secrets.twitter_client_secret
)
# --8<-- [end:TwitterOAuthIsConfigured]

TwitterCredentials = OAuth2Credentials
TwitterCredentialsInput = CredentialsMetaInput[
    Literal[ProviderName.TWITTER], Literal["oauth2"]
]


# Currently, we request all permissions from the Twitter API up front.
# In the future, if we need incremental permissions, we can use these requested_scopes.
def TwitterCredentialsField(scopes: list[str]) -> TwitterCredentialsInput:
    """
    Creates a Twitter credentials input on a block.

    Params:
        scopes: The authorization scopes needed for the block to work.
    """
    return CredentialsField(
        # required_scopes=set(scopes),
        required_scopes=set(TwitterOAuthHandler.DEFAULT_SCOPES + scopes),
        description="The Twitter integration requires OAuth2 authentication.",
    )


TEST_CREDENTIALS = OAuth2Credentials(
    id="01234567-89ab-cdef-0123-456789abcdef",
    provider="twitter",
    access_token=SecretStr("mock-twitter-access-token"),
    refresh_token=SecretStr("mock-twitter-refresh-token"),
    access_token_expires_at=1234567890,
    scopes=["tweet.read", "tweet.write", "users.read", "offline.access"],
    title="Mock Twitter OAuth2 Credentials",
    username="mock-twitter-username",
    refresh_token_expires_at=1234567890,
)

TEST_CREDENTIALS_INPUT = {
    "provider": TEST_CREDENTIALS.provider,
    "id": TEST_CREDENTIALS.id,
    "type": TEST_CREDENTIALS.type,
    "title": TEST_CREDENTIALS.title,
}
@@ -1,418 +0,0 @@
from datetime import datetime
from typing import Any, Dict

from backend.blocks.twitter._mappers import (
    get_backend_expansion,
    get_backend_field,
    get_backend_list_expansion,
    get_backend_list_field,
    get_backend_media_field,
    get_backend_place_field,
    get_backend_poll_field,
    get_backend_space_expansion,
    get_backend_space_field,
    get_backend_user_field,
)
from backend.blocks.twitter._types import (  # DMEventFieldFilter,
    DMEventExpansionFilter,
    DMEventTypeFilter,
    DMMediaFieldFilter,
    DMTweetFieldFilter,
    ExpansionFilter,
    ListExpansionsFilter,
    ListFieldsFilter,
    SpaceExpansionsFilter,
    SpaceFieldsFilter,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetReplySettingsFilter,
    TweetUserFieldsFilter,
    UserExpansionsFilter,
)


# Common Builders
class TweetExpansionsBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_expansions(self, expansions: ExpansionFilter | None):
        if expansions:
            filtered_expansions = [
                name for name, value in expansions.dict().items() if value is True
            ]

            if filtered_expansions:
                self.params["expansions"] = ",".join(
                    [get_backend_expansion(exp) for exp in filtered_expansions]
                )

        return self

    def add_media_fields(self, media_fields: TweetMediaFieldsFilter | None):
        if media_fields:
            filtered_fields = [
                name for name, value in media_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["media.fields"] = ",".join(
                    [get_backend_media_field(field) for field in filtered_fields]
                )
        return self

    def add_place_fields(self, place_fields: TweetPlaceFieldsFilter | None):
        if place_fields:
            filtered_fields = [
                name for name, value in place_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["place.fields"] = ",".join(
                    [get_backend_place_field(field) for field in filtered_fields]
                )
        return self

    def add_poll_fields(self, poll_fields: TweetPollFieldsFilter | None):
        if poll_fields:
            filtered_fields = [
                name for name, value in poll_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["poll.fields"] = ",".join(
                    [get_backend_poll_field(field) for field in filtered_fields]
                )
        return self

    def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
        if tweet_fields:
            filtered_fields = [
                name for name, value in tweet_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["tweet.fields"] = ",".join(
                    [get_backend_field(field) for field in filtered_fields]
                )
        return self

    def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
        if user_fields:
            filtered_fields = [
                name for name, value in user_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["user.fields"] = ",".join(
                    [get_backend_user_field(field) for field in filtered_fields]
                )
        return self

    def build(self):
        return self.params


class UserExpansionsBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_expansions(self, expansions: UserExpansionsFilter | None):
        if expansions:
            filtered_expansions = [
                name for name, value in expansions.dict().items() if value is True
            ]
            if filtered_expansions:
                self.params["expansions"] = ",".join(filtered_expansions)
        return self

    def add_tweet_fields(self, tweet_fields: TweetFieldsFilter | None):
        if tweet_fields:
            filtered_fields = [
                name for name, value in tweet_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["tweet.fields"] = ",".join(
                    [get_backend_field(field) for field in filtered_fields]
                )
        return self

    def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
        if user_fields:
            filtered_fields = [
                name for name, value in user_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["user.fields"] = ",".join(
                    [get_backend_user_field(field) for field in filtered_fields]
                )
        return self

    def build(self):
        return self.params


class ListExpansionsBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_expansions(self, expansions: ListExpansionsFilter | None):
        if expansions:
            filtered_expansions = [
                name for name, value in expansions.dict().items() if value is True
            ]
            if filtered_expansions:
                self.params["expansions"] = ",".join(
                    [get_backend_list_expansion(exp) for exp in filtered_expansions]
                )
        return self

    def add_list_fields(self, list_fields: ListFieldsFilter | None):
        if list_fields:
            filtered_fields = [
                name for name, value in list_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["list.fields"] = ",".join(
                    [get_backend_list_field(field) for field in filtered_fields]
                )
        return self

    def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
        if user_fields:
            filtered_fields = [
                name for name, value in user_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["user.fields"] = ",".join(
                    [get_backend_user_field(field) for field in filtered_fields]
                )
        return self

    def build(self):
        return self.params


class SpaceExpansionsBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_expansions(self, expansions: SpaceExpansionsFilter | None):
        if expansions:
            filtered_expansions = [
                name for name, value in expansions.dict().items() if value is True
            ]
            if filtered_expansions:
                self.params["expansions"] = ",".join(
                    [get_backend_space_expansion(exp) for exp in filtered_expansions]
                )
        return self

    def add_space_fields(self, space_fields: SpaceFieldsFilter | None):
        if space_fields:
            filtered_fields = [
                name for name, value in space_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["space.fields"] = ",".join(
                    [get_backend_space_field(field) for field in filtered_fields]
                )
        return self

    def add_user_fields(self, user_fields: TweetUserFieldsFilter | None):
        if user_fields:
            filtered_fields = [
                name for name, value in user_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["user.fields"] = ",".join(
                    [get_backend_user_field(field) for field in filtered_fields]
                )
        return self

    def build(self):
        return self.params


class TweetDurationBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_start_time(self, start_time: datetime | None):
        if start_time:
            self.params["start_time"] = start_time
        return self

    def add_end_time(self, end_time: datetime | None):
        if end_time:
            self.params["end_time"] = end_time
        return self

    def add_since_id(self, since_id: str | None):
        if since_id:
            self.params["since_id"] = since_id
        return self

    def add_until_id(self, until_id: str | None):
        if until_id:
            self.params["until_id"] = until_id
        return self

    def add_sort_order(self, sort_order: str | None):
        if sort_order:
            self.params["sort_order"] = sort_order
        return self

    def build(self):
        return self.params


class DMExpansionsBuilder:
    def __init__(self, param: Dict[str, Any]):
        self.params: Dict[str, Any] = param

    def add_expansions(self, expansions: DMEventExpansionFilter):
        if expansions:
            filtered_expansions = [
                name for name, value in expansions.dict().items() if value is True
            ]
            if filtered_expansions:
                self.params["expansions"] = ",".join(filtered_expansions)
        return self

    def add_event_types(self, event_types: DMEventTypeFilter):
        if event_types:
            filtered_types = [
                name for name, value in event_types.dict().items() if value is True
            ]
            if filtered_types:
                self.params["event_types"] = ",".join(filtered_types)
        return self

    def add_media_fields(self, media_fields: DMMediaFieldFilter):
        if media_fields:
            filtered_fields = [
                name for name, value in media_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["media.fields"] = ",".join(filtered_fields)
        return self

    def add_tweet_fields(self, tweet_fields: DMTweetFieldFilter):
        if tweet_fields:
            filtered_fields = [
                name for name, value in tweet_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["tweet.fields"] = ",".join(filtered_fields)
        return self

    def add_user_fields(self, user_fields: TweetUserFieldsFilter):
        if user_fields:
            filtered_fields = [
                name for name, value in user_fields.dict().items() if value is True
            ]
            if filtered_fields:
                self.params["user.fields"] = ",".join(filtered_fields)
        return self

    def build(self):
        return self.params


# Specific Builders
class TweetSearchBuilder:
    def __init__(self):
        self.params: Dict[str, Any] = {"user_auth": False}

    def add_query(self, query: str):
        if query:
            self.params["query"] = query
        return self

    def add_pagination(self, max_results: int, pagination: str | None):
        if max_results:
            self.params["max_results"] = max_results
        if pagination:
            self.params["pagination_token"] = pagination
        return self

    def build(self):
        return self.params


class TweetPostBuilder:
    def __init__(self):
        self.params: Dict[str, Any] = {"user_auth": False}

    def add_text(self, text: str | None):
        if text:
            self.params["text"] = text
        return self

    def add_media(self, media_ids: list, tagged_user_ids: list):
        if media_ids:
            self.params["media_ids"] = media_ids
        if tagged_user_ids:
            self.params["media_tagged_user_ids"] = tagged_user_ids
        return self

    def add_deep_link(self, link: str):
        if link:
            self.params["direct_message_deep_link"] = link
        return self

    def add_super_followers(self, for_super_followers: bool):
        if for_super_followers:
            self.params["for_super_followers_only"] = for_super_followers
        return self

    def add_place(self, place_id: str):
        if place_id:
            self.params["place_id"] = place_id
        return self

    def add_poll_options(self, poll_options: list):
        if poll_options:
            self.params["poll_options"] = poll_options
        return self

    def add_poll_duration(self, poll_duration_minutes: int):
        if poll_duration_minutes:
            self.params["poll_duration_minutes"] = poll_duration_minutes
        return self

    def add_quote(self, quote_id: str):
        if quote_id:
            self.params["quote_tweet_id"] = quote_id
        return self

    def add_reply_settings(
        self,
        exclude_user_ids: list,
        reply_to_id: str,
        settings: TweetReplySettingsFilter,
    ):
        if exclude_user_ids:
            self.params["exclude_reply_user_ids"] = exclude_user_ids
        if reply_to_id:
            self.params["in_reply_to_tweet_id"] = reply_to_id
        if settings.All_Users:
            self.params["reply_settings"] = None
        elif settings.Following_Users_Only:
            self.params["reply_settings"] = "following"
        elif settings.Mentioned_Users_Only:
            self.params["reply_settings"] = "mentionedUsers"
        return self

    def build(self):
        return self.params


class TweetGetsBuilder:
    def __init__(self):
        self.params: Dict[str, Any] = {"user_auth": False}

    def add_id(self, tweet_id: list[str]):
        self.params["id"] = tweet_id
        return self

    def build(self):
        return self.params
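# Usage sketch (illustrative, not part of the original file; assumes the
# module above is importable): the builders are fluent accumulators over a
# plain params dict, so a request is assembled by chaining add_* calls and
# finishing with build().

from backend.blocks.twitter._builders import TweetDurationBuilder, TweetSearchBuilder

params = (
    TweetSearchBuilder()
    .add_query("from:twitterdev -is:retweet")  # hypothetical query
    .add_pagination(max_results=10, pagination=None)
    .build()
)
# Time-window options can be layered onto the same dict afterwards:
params = TweetDurationBuilder(params).add_sort_order("recency").build()
# params == {"user_auth": False, "query": "from:twitterdev -is:retweet",
#            "max_results": 10, "sort_order": "recency"}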
@@ -1,234 +0,0 @@
# -------------- Tweets -----------------

# Tweet Expansions
EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {
    "Poll_IDs": "attachments.poll_ids",
    "Media_Keys": "attachments.media_keys",
    "Author_User_ID": "author_id",
    "Edit_History_Tweet_IDs": "edit_history_tweet_ids",
    "Mentioned_Usernames": "entities.mentions.username",
    "Place_ID": "geo.place_id",
    "Reply_To_User_ID": "in_reply_to_user_id",
    "Referenced_Tweet_ID": "referenced_tweets.id",
    "Referenced_Tweet_Author_ID": "referenced_tweets.id.author_id",
}


def get_backend_expansion(frontend_key: str) -> str:
    result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid expansion key: {frontend_key}")
    return result


# TweetReplySettings
REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING = {
    "Mentioned_Users_Only": "mentionedUsers",
    "Following_Users_Only": "following",
    "All_Users": "all",
}


def get_backend_reply_setting(frontend_key: str) -> str:
    result = REPLY_SETTINGS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid reply setting key: {frontend_key}")
    return result


# TweetUserFields
USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Account_Creation_Date": "created_at",
    "User_Bio": "description",
    "User_Entities": "entities",
    "User_ID": "id",
    "User_Location": "location",
    "Latest_Tweet_ID": "most_recent_tweet_id",
    "Display_Name": "name",
    "Pinned_Tweet_ID": "pinned_tweet_id",
    "Profile_Picture_URL": "profile_image_url",
    "Is_Protected_Account": "protected",
    "Account_Statistics": "public_metrics",
    "Profile_URL": "url",
    "Username": "username",
    "Is_Verified": "verified",
    "Verification_Type": "verified_type",
    "Content_Withholding_Info": "withheld",
}


def get_backend_user_field(frontend_key: str) -> str:
    result = USER_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid user field key: {frontend_key}")
    return result


# TweetFields
FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Tweet_Attachments": "attachments",
    "Author_ID": "author_id",
    "Context_Annotations": "context_annotations",
    "Conversation_ID": "conversation_id",
    "Creation_Time": "created_at",
    "Edit_Controls": "edit_controls",
    "Tweet_Entities": "entities",
    "Geographic_Location": "geo",
    "Tweet_ID": "id",
    "Reply_To_User_ID": "in_reply_to_user_id",
    "Language": "lang",
    "Public_Metrics": "public_metrics",
    "Sensitive_Content_Flag": "possibly_sensitive",
    "Referenced_Tweets": "referenced_tweets",
    "Reply_Settings": "reply_settings",
    "Tweet_Source": "source",
    "Tweet_Text": "text",
    "Withheld_Content": "withheld",
}


def get_backend_field(frontend_key: str) -> str:
    result = FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid field key: {frontend_key}")
    return result


# TweetPollFields
POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Duration_Minutes": "duration_minutes",
    "End_DateTime": "end_datetime",
    "Poll_ID": "id",
    "Poll_Options": "options",
    "Voting_Status": "voting_status",
}


def get_backend_poll_field(frontend_key: str) -> str:
    result = POLL_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid poll field key: {frontend_key}")
    return result


# TweetPlaceFields
PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Contained_Within_Places": "contained_within",
    "Country": "country",
    "Country_Code": "country_code",
    "Full_Location_Name": "full_name",
    "Geographic_Coordinates": "geo",
    "Place_ID": "id",
    "Place_Name": "name",
    "Place_Type": "place_type",
}


def get_backend_place_field(frontend_key: str) -> str:
    result = PLACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid place field key: {frontend_key}")
    return result


# TweetMediaFields
MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Duration_in_Milliseconds": "duration_ms",
    "Height": "height",
    "Media_Key": "media_key",
    "Preview_Image_URL": "preview_image_url",
    "Media_Type": "type",
    "Media_URL": "url",
    "Width": "width",
    "Public_Metrics": "public_metrics",
    "Non_Public_Metrics": "non_public_metrics",
    "Organic_Metrics": "organic_metrics",
    "Promoted_Metrics": "promoted_metrics",
    "Alternative_Text": "alt_text",
    "Media_Variants": "variants",
}


def get_backend_media_field(frontend_key: str) -> str:
    result = MEDIA_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid media field key: {frontend_key}")
    return result


# -------------- Spaces -----------------

# SpaceExpansions
EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE = {
    "Invited_Users": "invited_user_ids",
    "Speakers": "speaker_ids",
    "Creator": "creator_id",
    "Hosts": "host_ids",
    "Topics": "topic_ids",
}


def get_backend_space_expansion(frontend_key: str) -> str:
    result = EXPANSION_FRONTEND_TO_BACKEND_MAPPING_SPACE.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid expansion key: {frontend_key}")
    return result


# SpaceFields
SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "Space_ID": "id",
    "Space_State": "state",
    "Creation_Time": "created_at",
    "End_Time": "ended_at",
    "Host_User_IDs": "host_ids",
    "Language": "lang",
    "Is_Ticketed": "is_ticketed",
    "Invited_User_IDs": "invited_user_ids",
    "Participant_Count": "participant_count",
    "Subscriber_Count": "subscriber_count",
    "Scheduled_Start_Time": "scheduled_start",
    "Speaker_User_IDs": "speaker_ids",
    "Start_Time": "started_at",
    "Space_Title": "title",
    "Topic_IDs": "topic_ids",
    "Last_Updated_Time": "updated_at",
}


def get_backend_space_field(frontend_key: str) -> str:
    result = SPACE_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid space field key: {frontend_key}")
    return result


# -------------- List Expansions -----------------

# ListExpansions
LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING = {"List_Owner_ID": "owner_id"}


def get_backend_list_expansion(frontend_key: str) -> str:
    result = LIST_EXPANSION_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid list expansion key: {frontend_key}")
    return result


# ListFields
LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING = {
    "List_ID": "id",
    "List_Name": "name",
    "Creation_Date": "created_at",
    "Description": "description",
    "Follower_Count": "follower_count",
    "Member_Count": "member_count",
    "Is_Private": "private",
    "Owner_ID": "owner_id",
}


def get_backend_list_field(frontend_key: str) -> str:
    result = LIST_FIELDS_FRONTEND_TO_BACKEND_MAPPING.get(frontend_key)
    if result is None:
        raise KeyError(f"Invalid list field key: {frontend_key}")
    return result
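# Usage sketch (illustrative, not part of the original file): each
# get_backend_* helper translates one human-readable frontend key into the
# exact field name the Twitter v2 API expects, and fails fast on unknown keys.

from backend.blocks.twitter._mappers import get_backend_expansion

assert get_backend_expansion("Media_Keys") == "attachments.media_keys"
try:
    get_backend_expansion("Not_A_Key")  # hypothetical invalid key
except KeyError as exc:
    print(exc)  # 'Invalid expansion key: Not_A_Key'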
@@ -1,76 +0,0 @@
from typing import Any, Dict, List


class BaseSerializer:
    @staticmethod
    def _serialize_value(value: Any) -> Any:
        """Helper method to serialize individual values"""
        if hasattr(value, "data"):
            return value.data
        return value


class IncludesSerializer(BaseSerializer):
    @classmethod
    def serialize(cls, includes: Dict[str, Any]) -> Dict[str, Any]:
        """Serializes the includes dictionary"""
        if not includes:
            return {}

        serialized_includes = {}
        for key, value in includes.items():
            if isinstance(value, list):
                serialized_includes[key] = [
                    cls._serialize_value(item) for item in value
                ]
            else:
                serialized_includes[key] = cls._serialize_value(value)

        return serialized_includes


class ResponseDataSerializer(BaseSerializer):
    @classmethod
    def serialize_dict(cls, item: Dict[str, Any]) -> Dict[str, Any]:
        """Serializes a single dictionary item"""
        serialized_item = {}

        if hasattr(item, "__dict__"):
            items = item.__dict__.items()
        else:
            items = item.items()

        for key, value in items:
            if isinstance(value, list):
                serialized_item[key] = [
                    cls._serialize_value(sub_item) for sub_item in value
                ]
            else:
                serialized_item[key] = cls._serialize_value(value)

        return serialized_item

    @classmethod
    def serialize_list(cls, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Serializes a list of dictionary items"""
        return [cls.serialize_dict(item) for item in data]


class ResponseSerializer:
    @classmethod
    def serialize(cls, response) -> Dict[str, Any]:
        """Main serializer that handles both data and includes"""
        result = {"data": None, "included": {}}

        # Handle response.data
        if response.data:
            if isinstance(response.data, list):
                result["data"] = ResponseDataSerializer.serialize_list(response.data)
            else:
                result["data"] = ResponseDataSerializer.serialize_dict(response.data)

        # Handle includes
        if hasattr(response, "includes") and response.includes:
            result["included"] = IncludesSerializer.serialize(response.includes)

        return result
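# Usage sketch (illustrative, not part of the original file): the serializer
# only touches the .data and .includes attributes, so any object exposing
# them (not just a tweepy.client.Response) can be serialized, which makes it
# easy to exercise in tests.

from types import SimpleNamespace

from backend.blocks.twitter._serializer import ResponseSerializer

stub = SimpleNamespace(
    data=[{"id": "1", "text": "hello"}],
    includes={"users": [{"id": "2244994945", "username": "testuser"}]},
)
print(ResponseSerializer.serialize(stub))
# {'data': [{'id': '1', 'text': 'hello'}],
#  'included': {'users': [{'id': '2244994945', 'username': 'testuser'}]}}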
@@ -1,443 +0,0 @@
from datetime import datetime
from enum import Enum

from pydantic import BaseModel

from backend.data.block import BlockSchema
from backend.data.model import SchemaField

# -------------- Tweets -----------------


class TweetReplySettingsFilter(BaseModel):
    Mentioned_Users_Only: bool = False
    Following_Users_Only: bool = False
    All_Users: bool = False


class TweetUserFieldsFilter(BaseModel):
    Account_Creation_Date: bool = False
    User_Bio: bool = False
    User_Entities: bool = False
    User_ID: bool = False
    User_Location: bool = False
    Latest_Tweet_ID: bool = False
    Display_Name: bool = False
    Pinned_Tweet_ID: bool = False
    Profile_Picture_URL: bool = False
    Is_Protected_Account: bool = False
    Account_Statistics: bool = False
    Profile_URL: bool = False
    Username: bool = False
    Is_Verified: bool = False
    Verification_Type: bool = False
    Content_Withholding_Info: bool = False


class TweetFieldsFilter(BaseModel):
    Tweet_Attachments: bool = False
    Author_ID: bool = False
    Context_Annotations: bool = False
    Conversation_ID: bool = False
    Creation_Time: bool = False
    Edit_Controls: bool = False
    Tweet_Entities: bool = False
    Geographic_Location: bool = False
    Tweet_ID: bool = False
    Reply_To_User_ID: bool = False
    Language: bool = False
    Public_Metrics: bool = False
    Sensitive_Content_Flag: bool = False
    Referenced_Tweets: bool = False
    Reply_Settings: bool = False
    Tweet_Source: bool = False
    Tweet_Text: bool = False
    Withheld_Content: bool = False


class PersonalTweetFieldsFilter(BaseModel):
    attachments: bool = False
    author_id: bool = False
    context_annotations: bool = False
    conversation_id: bool = False
    created_at: bool = False
    edit_controls: bool = False
    entities: bool = False
    geo: bool = False
    id: bool = False
    in_reply_to_user_id: bool = False
    lang: bool = False
    non_public_metrics: bool = False
    public_metrics: bool = False
    organic_metrics: bool = False
    promoted_metrics: bool = False
    possibly_sensitive: bool = False
    referenced_tweets: bool = False
    reply_settings: bool = False
    source: bool = False
    text: bool = False
    withheld: bool = False


class TweetPollFieldsFilter(BaseModel):
    Duration_Minutes: bool = False
    End_DateTime: bool = False
    Poll_ID: bool = False
    Poll_Options: bool = False
    Voting_Status: bool = False


class TweetPlaceFieldsFilter(BaseModel):
    Contained_Within_Places: bool = False
    Country: bool = False
    Country_Code: bool = False
    Full_Location_Name: bool = False
    Geographic_Coordinates: bool = False
    Place_ID: bool = False
    Place_Name: bool = False
    Place_Type: bool = False


class TweetMediaFieldsFilter(BaseModel):
    Duration_in_Milliseconds: bool = False
    Height: bool = False
    Media_Key: bool = False
    Preview_Image_URL: bool = False
    Media_Type: bool = False
    Media_URL: bool = False
    Width: bool = False
    Public_Metrics: bool = False
    Non_Public_Metrics: bool = False
    Organic_Metrics: bool = False
    Promoted_Metrics: bool = False
    Alternative_Text: bool = False
    Media_Variants: bool = False


class ExpansionFilter(BaseModel):
    Poll_IDs: bool = False
    Media_Keys: bool = False
    Author_User_ID: bool = False
    Edit_History_Tweet_IDs: bool = False
    Mentioned_Usernames: bool = False
    Place_ID: bool = False
    Reply_To_User_ID: bool = False
    Referenced_Tweet_ID: bool = False
    Referenced_Tweet_Author_ID: bool = False


class TweetExcludesFilter(BaseModel):
    retweets: bool = False
    replies: bool = False


# -------------- Users -----------------


class UserExpansionsFilter(BaseModel):
    pinned_tweet_id: bool = False


# -------------- DMs -----------------


class DMEventFieldFilter(BaseModel):
    id: bool = False
    text: bool = False
    event_type: bool = False
    created_at: bool = False
    dm_conversation_id: bool = False
    sender_id: bool = False
    participant_ids: bool = False
    referenced_tweets: bool = False
    attachments: bool = False


class DMEventTypeFilter(BaseModel):
    MessageCreate: bool = False
    ParticipantsJoin: bool = False
    ParticipantsLeave: bool = False


class DMEventExpansionFilter(BaseModel):
    attachments_media_keys: bool = False
    referenced_tweets_id: bool = False
    sender_id: bool = False
    participant_ids: bool = False


class DMMediaFieldFilter(BaseModel):
    duration_ms: bool = False
    height: bool = False
    media_key: bool = False
    preview_image_url: bool = False
    type: bool = False
    url: bool = False
    width: bool = False
    public_metrics: bool = False
    alt_text: bool = False
    variants: bool = False


class DMTweetFieldFilter(BaseModel):
    attachments: bool = False
    author_id: bool = False
    context_annotations: bool = False
    conversation_id: bool = False
    created_at: bool = False
    edit_controls: bool = False
    entities: bool = False
    geo: bool = False
    id: bool = False
    in_reply_to_user_id: bool = False
    lang: bool = False
    public_metrics: bool = False
    possibly_sensitive: bool = False
    referenced_tweets: bool = False
    reply_settings: bool = False
    source: bool = False
    text: bool = False
    withheld: bool = False


# -------------- Spaces -----------------


class SpaceExpansionsFilter(BaseModel):
    Invited_Users: bool = False
    Speakers: bool = False
    Creator: bool = False
    Hosts: bool = False
    Topics: bool = False


class SpaceFieldsFilter(BaseModel):
    Space_ID: bool = False
    Space_State: bool = False
    Creation_Time: bool = False
    End_Time: bool = False
    Host_User_IDs: bool = False
    Language: bool = False
    Is_Ticketed: bool = False
    Invited_User_IDs: bool = False
    Participant_Count: bool = False
    Subscriber_Count: bool = False
    Scheduled_Start_Time: bool = False
    Speaker_User_IDs: bool = False
    Start_Time: bool = False
    Space_Title: bool = False
    Topic_IDs: bool = False
    Last_Updated_Time: bool = False


class SpaceStatesFilter(str, Enum):
    live = "live"
    scheduled = "scheduled"
    all = "all"


# -------------- List Expansions -----------------


class ListExpansionsFilter(BaseModel):
    List_Owner_ID: bool = False


class ListFieldsFilter(BaseModel):
    List_ID: bool = False
    List_Name: bool = False
    Creation_Date: bool = False
    Description: bool = False
    Follower_Count: bool = False
    Member_Count: bool = False
    Is_Private: bool = False
    Owner_ID: bool = False


# --------- [Input Types] -------------
class TweetExpansionInputs(BlockSchema):
    expansions: ExpansionFilter | None = SchemaField(
        description="Choose what extra information you want to get with your tweets. For example:\n- Select 'Media_Keys' to get media details\n- Select 'Author_User_ID' to get user information\n- Select 'Place_ID' to get location details",
        placeholder="Pick the extra information you want to see",
        default=None,
        advanced=True,
    )

    media_fields: TweetMediaFieldsFilter | None = SchemaField(
        description="Select what media information you want to see (images, videos, etc). To use this, you must first select 'Media_Keys' in the expansions above.",
        placeholder="Choose what media details you want to see",
        default=None,
        advanced=True,
    )

    place_fields: TweetPlaceFieldsFilter | None = SchemaField(
        description="Select what location information you want to see (country, coordinates, etc). To use this, you must first select 'Place_ID' in the expansions above.",
        placeholder="Choose what location details you want to see",
        default=None,
        advanced=True,
    )

    poll_fields: TweetPollFieldsFilter | None = SchemaField(
        description="Select what poll information you want to see (options, voting status, etc). To use this, you must first select 'Poll_IDs' in the expansions above.",
        placeholder="Choose what poll details you want to see",
        default=None,
        advanced=True,
    )

    tweet_fields: TweetFieldsFilter | None = SchemaField(
        description="Select what tweet information you want to see. For referenced tweets (like retweets), select 'Referenced_Tweet_ID' in the expansions above.",
        placeholder="Choose what tweet details you want to see",
        default=None,
        advanced=True,
    )

    user_fields: TweetUserFieldsFilter | None = SchemaField(
        description="Select what user information you want to see. To use this, you must first select one of these in expansions above:\n- 'Author_User_ID' for tweet authors\n- 'Mentioned_Usernames' for mentioned users\n- 'Reply_To_User_ID' for users being replied to\n- 'Referenced_Tweet_Author_ID' for authors of referenced tweets",
        placeholder="Choose what user details you want to see",
        default=None,
        advanced=True,
    )


class DMEventExpansionInputs(BlockSchema):
    expansions: DMEventExpansionFilter | None = SchemaField(
        description="Select expansions to include related data objects in the 'includes' section.",
        placeholder="Enter expansions",
        default=None,
        advanced=True,
    )

    event_types: DMEventTypeFilter | None = SchemaField(
        description="Select DM event types to include in the response.",
        placeholder="Enter event types",
        default=None,
        advanced=True,
    )

    media_fields: DMMediaFieldFilter | None = SchemaField(
        description="Select media fields to include in the response (requires expansions=attachments.media_keys).",
        placeholder="Enter media fields",
        default=None,
        advanced=True,
    )

    tweet_fields: DMTweetFieldFilter | None = SchemaField(
        description="Select tweet fields to include in the response (requires expansions=referenced_tweets.id).",
        placeholder="Enter tweet fields",
        default=None,
        advanced=True,
    )

    user_fields: TweetUserFieldsFilter | None = SchemaField(
        description="Select user fields to include in the response (requires expansions=sender_id or participant_ids).",
        placeholder="Enter user fields",
        default=None,
        advanced=True,
    )


class UserExpansionInputs(BlockSchema):
    expansions: UserExpansionsFilter | None = SchemaField(
        description="Choose what extra information you want to get with user data. Currently only 'pinned_tweet_id' is available to see a user's pinned tweet.",
        placeholder="Select extra user information to include",
        default=None,
        advanced=True,
    )

    tweet_fields: TweetFieldsFilter | None = SchemaField(
        description="Select what tweet information you want to see in pinned tweets. This only works if you select 'pinned_tweet_id' in expansions above.",
        placeholder="Choose what details to see in pinned tweets",
        default=None,
        advanced=True,
    )

    user_fields: TweetUserFieldsFilter | None = SchemaField(
        description="Select what user information you want to see, like username, bio, profile picture, etc.",
        placeholder="Choose what user details you want to see",
        default=None,
        advanced=True,
    )


class SpaceExpansionInputs(BlockSchema):
    expansions: SpaceExpansionsFilter | None = SchemaField(
        description="Choose additional information you want to get with your Twitter Spaces:\n- Select 'Invited_Users' to see who was invited\n- Select 'Speakers' to see who can speak\n- Select 'Creator' to get details about who made the Space\n- Select 'Hosts' to see who's hosting\n- Select 'Topics' to see Space topics",
        placeholder="Pick what extra information you want to see about the Space",
        default=None,
        advanced=True,
    )

    space_fields: SpaceFieldsFilter | None = SchemaField(
        description="Choose what Space details you want to see, such as:\n- Title\n- Start/End times\n- Number of participants\n- Language\n- State (live/scheduled)\n- And more",
        placeholder="Choose what Space information you want to get",
        default=SpaceFieldsFilter(Space_Title=True, Host_User_IDs=True),
        advanced=True,
    )

    user_fields: TweetUserFieldsFilter | None = SchemaField(
        description="Choose what user information you want to see. This works when you select any of these in expansions above:\n- 'Creator' for Space creator details\n- 'Hosts' for host information\n- 'Speakers' for speaker details\n- 'Invited_Users' for invited user information",
        placeholder="Pick what details you want to see about the users",
        default=None,
        advanced=True,
    )


class ListExpansionInputs(BlockSchema):
    expansions: ListExpansionsFilter | None = SchemaField(
        description="Choose what extra information you want to get with your Twitter Lists:\n- Select 'List_Owner_ID' to get details about who owns the list\n\nThis will let you see more details about the list owner when you also select user fields below.",
        placeholder="Pick what extra list information you want to see",
        default=ListExpansionsFilter(List_Owner_ID=True),
        advanced=True,
    )

    user_fields: TweetUserFieldsFilter | None = SchemaField(
        description="Choose what information you want to see about list owners. This only works when you select 'List_Owner_ID' in expansions above.\n\nYou can see things like:\n- Their username\n- Profile picture\n- Account details\n- And more",
        placeholder="Select what details you want to see about list owners",
        default=TweetUserFieldsFilter(User_ID=True, Username=True),
        advanced=True,
    )

    list_fields: ListFieldsFilter | None = SchemaField(
        description="Choose what information you want to see about the Twitter Lists themselves, such as:\n- List name\n- Description\n- Number of followers\n- Number of members\n- Whether it's private\n- Creation date\n- And more",
        placeholder="Pick what list details you want to see",
        default=ListFieldsFilter(Owner_ID=True),
        advanced=True,
    )


class TweetTimeWindowInputs(BlockSchema):
    start_time: datetime | None = SchemaField(
        description="Start time in YYYY-MM-DDTHH:mm:ssZ format",
        placeholder="Enter start time",
        default=None,
        advanced=False,
    )

    end_time: datetime | None = SchemaField(
        description="End time in YYYY-MM-DDTHH:mm:ssZ format",
        placeholder="Enter end time",
        default=None,
        advanced=False,
    )

    since_id: str | None = SchemaField(
        description="Returns results with a Tweet ID greater than this (i.e. more recent). since_id takes priority over start_time.",
        placeholder="Enter since ID",
        default=None,
        advanced=True,
    )

    until_id: str | None = SchemaField(
        description="Returns results with a Tweet ID less than this (i.e. older). Used together with since_id.",
        placeholder="Enter until ID",
        default=None,
        advanced=True,
    )

    sort_order: str | None = SchemaField(
        description="Order of returned tweets (recency or relevancy)",
        placeholder="Enter sort order",
        default=None,
        advanced=True,
    )
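# Usage sketch (illustrative, not part of the original file): the filters
# above are bags of booleans; a builder keeps only the fields explicitly set
# to True and maps them to their backend names, so pydantic defaults never
# leak into the request parameters.

from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._types import ExpansionFilter, TweetUserFieldsFilter

params = (
    TweetExpansionsBuilder({"user_auth": False})
    .add_expansions(ExpansionFilter(Author_User_ID=True))
    .add_user_fields(TweetUserFieldsFilter(Username=True))
    .build()
)
# params == {"user_auth": False, "expansions": "author_id",
#            "user.fields": "username"}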
@@ -1,201 +0,0 @@
# TODO: Add new type support

# from typing import cast
# import tweepy
# from tweepy.client import Response

# from backend.blocks.twitter._serializer import IncludesSerializer, ResponseDataSerializer
# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
# from backend.data.model import SchemaField
# from backend.blocks.twitter._builders import DMExpansionsBuilder
# from backend.blocks.twitter._types import DMEventExpansion, DMEventExpansionInputs, DMEventType, DMMediaField, DMTweetField, TweetUserFields
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
# from backend.blocks.twitter._auth import (
#     TEST_CREDENTIALS,
#     TEST_CREDENTIALS_INPUT,
#     TwitterCredentials,
#     TwitterCredentialsField,
#     TwitterCredentialsInput,
# )

# Requires Pro or Enterprise plan [Manual Testing Required]
# class TwitterGetDMEventsBlock(Block):
#     """
#     Gets a list of Direct Message events for the authenticated user
#     """

#     class Input(DMEventExpansionInputs):
#         credentials: TwitterCredentialsInput = TwitterCredentialsField(
#             ["dm.read", "offline.access", "user.read", "tweet.read"]
#         )

#         dm_conversation_id: str = SchemaField(
#             description="The ID of the Direct Message conversation",
#             placeholder="Enter conversation ID",
#             required=True,
#         )

#         max_results: int = SchemaField(
#             description="Maximum number of results to return (1-100)",
#             placeholder="Enter max results",
#             advanced=True,
#             default=10,
#         )

#         pagination_token: str = SchemaField(
#             description="Token for pagination",
#             placeholder="Enter pagination token",
#             advanced=True,
#             default="",
#         )

#     class Output(BlockSchema):
#         # Common outputs
#         event_ids: list[str] = SchemaField(description="DM Event IDs")
#         event_texts: list[str] = SchemaField(description="DM Event text contents")
#         event_types: list[str] = SchemaField(description="Types of DM events")
#         next_token: str = SchemaField(description="Token for next page of results")

#         # Complete outputs
#         data: list[dict] = SchemaField(description="Complete DM events data")
#         included: dict = SchemaField(description="Additional data requested via expansions")
#         meta: dict = SchemaField(description="Metadata about the response")
#         error: str = SchemaField(description="Error message if request failed")

#     def __init__(self):
#         super().__init__(
#             id="dc37a6d4-a62e-11ef-a3a5-03061375737b",
#             description="This block retrieves Direct Message events for the authenticated user.",
#             categories={BlockCategory.SOCIAL},
#             input_schema=TwitterGetDMEventsBlock.Input,
#             output_schema=TwitterGetDMEventsBlock.Output,
#             test_input={
#                 "dm_conversation_id": "1234567890",
#                 "max_results": 10,
#                 "credentials": TEST_CREDENTIALS_INPUT,
#                 "expansions": [],
#                 "event_types": [],
#                 "media_fields": [],
#                 "tweet_fields": [],
#                 "user_fields": [],
#             },
#             test_credentials=TEST_CREDENTIALS,
#             test_output=[
#                 ("event_ids", ["1346889436626259968"]),
#                 ("event_texts", ["Hello just you..."]),
#                 ("event_types", ["MessageCreate"]),
#                 ("next_token", None),
#                 ("data", [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}]),
#                 ("included", {}),
#                 ("meta", {}),
#                 ("error", ""),
#             ],
#             test_mock={
#                 "get_dm_events": lambda *args, **kwargs: (
#                     [{"id": "1346889436626259968", "text": "Hello just you...", "event_type": "MessageCreate"}],
#                     {},
#                     {},
#                     ["1346889436626259968"],
#                     ["Hello just you..."],
#                     ["MessageCreate"],
#                     None,
#                 )
#             },
#         )

#     @staticmethod
#     def get_dm_events(
#         credentials: TwitterCredentials,
#         dm_conversation_id: str,
#         max_results: int,
#         pagination_token: str,
#         expansions: list[DMEventExpansion],
#         event_types: list[DMEventType],
#         media_fields: list[DMMediaField],
#         tweet_fields: list[DMTweetField],
#         user_fields: list[TweetUserFields],
#     ):
#         try:
#             client = tweepy.Client(
#                 bearer_token=credentials.access_token.get_secret_value()
#             )

#             params = {
#                 "dm_conversation_id": dm_conversation_id,
#                 "max_results": max_results,
#                 "pagination_token": None if pagination_token == "" else pagination_token,
#                 "user_auth": False,
#             }

#             params = (
#                 DMExpansionsBuilder(params)
#                 .add_expansions(expansions)
#                 .add_event_types(event_types)
#                 .add_media_fields(media_fields)
#                 .add_tweet_fields(tweet_fields)
#                 .add_user_fields(user_fields)
#                 .build()
#             )

#             response = cast(Response, client.get_direct_message_events(**params))

#             meta = {}
#             event_ids = []
#             event_texts = []
#             event_types = []
#             next_token = None

#             if response.meta:
#                 meta = response.meta
#                 next_token = meta.get("next_token")

#             included = IncludesSerializer.serialize(response.includes)
#             data = ResponseDataSerializer.serialize_list(response.data)

#             if response.data:
#                 event_ids = [str(item.id) for item in response.data]
#                 event_texts = [item.text if hasattr(item, "text") else None for item in response.data]
#                 event_types = [item.event_type for item in response.data]

#                 return data, included, meta, event_ids, event_texts, event_types, next_token

#             raise Exception("No DM events found")

#         except tweepy.TweepyException:
#             raise

#     def run(
#         self,
#         input_data: Input,
#         *,
#         credentials: TwitterCredentials,
#         **kwargs,
#     ) -> BlockOutput:
#         try:
#             event_data, included, meta, event_ids, event_texts, event_types, next_token = self.get_dm_events(
#                 credentials,
#                 input_data.dm_conversation_id,
#                 input_data.max_results,
#                 input_data.pagination_token,
#                 input_data.expansions,
#                 input_data.event_types,
#                 input_data.media_fields,
#                 input_data.tweet_fields,
#                 input_data.user_fields,
#             )

#             if event_ids:
#                 yield "event_ids", event_ids
#             if event_texts:
#                 yield "event_texts", event_texts
#             if event_types:
#                 yield "event_types", event_types
#             if next_token:
#                 yield "next_token", next_token
#             if event_data:
#                 yield "data", event_data
#             if included:
#                 yield "included", included
#             if meta:
#                 yield "meta", meta

#         except Exception as e:
#             yield "error", handle_tweepy_exception(e)
@@ -1,260 +0,0 @@
# TODO: Add new type support

# from typing import cast

# import tweepy
# from tweepy.client import Response

# from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
# from backend.data.model import SchemaField
# from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
# from backend.blocks.twitter._auth import (
#     TEST_CREDENTIALS,
#     TEST_CREDENTIALS_INPUT,
#     TwitterCredentials,
#     TwitterCredentialsField,
#     TwitterCredentialsInput,
# )

# Requires Pro or Enterprise plan [Manual Testing Required]
# class TwitterSendDirectMessageBlock(Block):
#     """
#     Sends a direct message to a Twitter user
#     """

#     class Input(BlockSchema):
#         credentials: TwitterCredentialsInput = TwitterCredentialsField(
#             ["offline.access", "direct_messages.write"]
#         )

#         participant_id: str = SchemaField(
#             description="The User ID of the account to send DM to",
#             placeholder="Enter recipient user ID",
#             default="",
#             advanced=False,
#         )

#         dm_conversation_id: str = SchemaField(
#             description="The conversation ID to send message to",
#             placeholder="Enter conversation ID",
#             default="",
#             advanced=False,
#         )

#         text: str = SchemaField(
#             description="Text of the Direct Message (up to 10,000 characters)",
#             placeholder="Enter message text",
#             default="",
#             advanced=False,
#         )

#         media_id: str = SchemaField(
#             description="Media ID to attach to the message",
#             placeholder="Enter media ID",
#             default="",
#         )

#     class Output(BlockSchema):
#         dm_event_id: str = SchemaField(description="ID of the sent direct message")
#         dm_conversation_id_: str = SchemaField(description="ID of the conversation")
#         error: str = SchemaField(description="Error message if sending failed")

#     def __init__(self):
#         super().__init__(
#             id="f32f2786-a62e-11ef-a93d-a3ef199dde7f",
#             description="This block sends a direct message to a specified Twitter user.",
#             categories={BlockCategory.SOCIAL},
#             input_schema=TwitterSendDirectMessageBlock.Input,
#             output_schema=TwitterSendDirectMessageBlock.Output,
#             test_input={
#                 "participant_id": "783214",
#                 "dm_conversation_id": "",
#                 "text": "Hello from Twitter API",
#                 "media_id": "",
#                 "credentials": TEST_CREDENTIALS_INPUT,
#             },
#             test_credentials=TEST_CREDENTIALS,
#             test_output=[
#                 ("dm_event_id", "0987654321"),
#                 ("dm_conversation_id_", "1234567890"),
#                 ("error", ""),
#             ],
#             test_mock={
#                 "send_direct_message": lambda *args, **kwargs: (
#                     "0987654321",
#                     "1234567890",
#                 )
#             },
#         )

#     @staticmethod
#     def send_direct_message(
#         credentials: TwitterCredentials,
#         participant_id: str,
#         dm_conversation_id: str,
#         text: str,
#         media_id: str,
#     ):
#         try:
#             client = tweepy.Client(
#                 bearer_token=credentials.access_token.get_secret_value()
#             )

#             response = cast(
#                 Response,
#                 client.create_direct_message(
#                     participant_id=None if participant_id == "" else participant_id,
#                     dm_conversation_id=None if dm_conversation_id == "" else dm_conversation_id,
#                     text=None if text == "" else text,
#                     media_id=None if media_id == "" else media_id,
#                     user_auth=False,
#                 ),
#             )

#             if not response.data:
#                 raise Exception("Failed to send direct message")

#             return response.data["dm_event_id"], response.data["dm_conversation_id"]

#         except tweepy.TweepyException:
#             raise
#         except Exception as e:
#             print(f"Unexpected error: {str(e)}")
#             raise

#     def run(
#         self,
#         input_data: Input,
#         *,
#         credentials: TwitterCredentials,
#         **kwargs,
#     ) -> BlockOutput:
#         try:
#             dm_event_id, dm_conversation_id = self.send_direct_message(
#                 credentials,
#                 input_data.participant_id,
#                 input_data.dm_conversation_id,
#                 input_data.text,
#                 input_data.media_id,
#             )
#             yield "dm_event_id", dm_event_id
#             # Output name matches the schema's "dm_conversation_id_" field
#             yield "dm_conversation_id_", dm_conversation_id

#         except Exception as e:
#             yield "error", handle_tweepy_exception(e)


# class TwitterCreateDMConversationBlock(Block):
#     """
#     Creates a new group direct message conversation on Twitter
#     """

#     class Input(BlockSchema):
#         credentials: TwitterCredentialsInput = TwitterCredentialsField(
#             ["offline.access", "dm.write", "dm.read", "tweet.read", "user.read"]
#         )

#         participant_ids: list[str] = SchemaField(
#             description="Array of User IDs to create conversation with (max 50)",
#             placeholder="Enter participant user IDs",
#             default=[],
#             advanced=False,
#         )

#         text: str = SchemaField(
#             description="Text of the Direct Message (up to 10,000 characters)",
#             placeholder="Enter message text",
#             default="",
#             advanced=False,
#         )

#         media_id: str = SchemaField(
#             description="Media ID to attach to the message",
#             placeholder="Enter media ID",
#             default="",
#             advanced=False,
#         )

#     class Output(BlockSchema):
#         dm_event_id: str = SchemaField(description="ID of the sent direct message")
#         dm_conversation_id: str = SchemaField(description="ID of the conversation")
#         error: str = SchemaField(description="Error message if sending failed")

#     def __init__(self):
#         super().__init__(
#             id="ec11cabc-a62e-11ef-8c0e-3fe37ba2ec92",
#             description="This block creates a new group DM conversation with specified Twitter users.",
#             categories={BlockCategory.SOCIAL},
#             input_schema=TwitterCreateDMConversationBlock.Input,
#             output_schema=TwitterCreateDMConversationBlock.Output,
#             test_input={
#                 "participant_ids": ["783214", "2244994945"],
#                 "text": "Hello from Twitter API",
#                 "media_id": "",
#                 "credentials": TEST_CREDENTIALS_INPUT,
#             },
#             test_credentials=TEST_CREDENTIALS,
#             test_output=[
#                 ("dm_event_id", "0987654321"),
#                 ("dm_conversation_id", "1234567890"),
#                 ("error", ""),
#             ],
#             test_mock={
#                 "create_dm_conversation": lambda *args, **kwargs: (
#                     "0987654321",
#                     "1234567890",
#                 )
#             },
#         )

#     @staticmethod
#     def create_dm_conversation(
#         credentials: TwitterCredentials,
#         participant_ids: list[str],
#         text: str,
#         media_id: str,
#     ):
#         try:
#             client = tweepy.Client(
#                 bearer_token=credentials.access_token.get_secret_value()
#             )

#             response = cast(
#                 Response,
#                 client.create_direct_message_conversation(
#                     participant_ids=participant_ids,
#                     text=None if text == "" else text,
#                     media_id=None if media_id == "" else media_id,
#                     user_auth=False,
#                 ),
#             )

#             if not response.data:
#                 raise Exception("Failed to create DM conversation")

#             return response.data["dm_event_id"], response.data["dm_conversation_id"]

#         except tweepy.TweepyException:
#             raise
#         except Exception as e:
#             print(f"Unexpected error: {str(e)}")
#             raise

#     def run(
#         self,
#         input_data: Input,
#         *,
#         credentials: TwitterCredentials,
#         **kwargs,
#     ) -> BlockOutput:
#         try:
#             dm_event_id, dm_conversation_id = self.create_dm_conversation(
#                 credentials,
#                 input_data.participant_ids,
#                 input_data.text,
#                 input_data.media_id,
#             )
#             yield "dm_event_id", dm_event_id
#             yield "dm_conversation_id", dm_conversation_id

#         except Exception as e:
#             yield "error", handle_tweepy_exception(e)
@@ -1,470 +0,0 @@
# from typing import cast
import tweepy

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)

# from backend.blocks.twitter._builders import UserExpansionsBuilder
# from backend.blocks.twitter._types import TweetFields, TweetUserFields, UserExpansionInputs, UserExpansions
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField

# from tweepy.client import Response


class TwitterUnfollowListBlock(Block):
    """
    Unfollows a Twitter list for the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["follows.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to unfollow",
            placeholder="Enter list ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the unfollow was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="1f43310a-a62f-11ef-8276-2b06a1bbae1a",
            description="This block unfollows a specified Twitter list for the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnfollowListBlock.Input,
            output_schema=TwitterUnfollowListBlock.Output,
            test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unfollow_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unfollow_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unfollow_list(list_id=list_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unfollow_list(credentials, input_data.list_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterFollowListBlock(Block):
    """
    Follows a Twitter list for the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "list.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to follow",
            placeholder="Enter list ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the follow was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="03d8acf6-a62f-11ef-b17f-b72b04a09e79",
            description="This block follows a specified Twitter list for the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterFollowListBlock.Input,
            output_schema=TwitterFollowListBlock.Output,
            test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"follow_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def follow_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.follow_list(list_id=list_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.follow_list(credentials, input_data.list_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)
||||
|
||||
# Enterprise Level [Need to do Manual testing], There is a high possibility that we might get error in this
|
||||
# Needs Type Input in this
|
||||
|
||||
# class TwitterListGetFollowersBlock(Block):
|
||||
# """
|
||||
# Gets followers of a specified Twitter list
|
||||
# """
|
||||
|
||||
# class Input(UserExpansionInputs):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["tweet.read","users.read", "list.read", "offline.access"]
|
||||
# )
|
||||
|
||||
# list_id: str = SchemaField(
|
||||
# description="The ID of the List to get followers for",
|
||||
# placeholder="Enter list ID",
|
||||
# required=True
|
||||
# )
|
||||
|
||||
# max_results: int = SchemaField(
|
||||
# description="Max number of results per page (1-100)",
|
||||
# placeholder="Enter max results",
|
||||
# default=10,
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# pagination_token: str = SchemaField(
|
||||
# description="Token for pagination",
|
||||
# placeholder="Enter pagination token",
|
||||
# default="",
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# user_ids: list[str] = SchemaField(description="List of user IDs of followers")
|
||||
# usernames: list[str] = SchemaField(description="List of usernames of followers")
|
||||
# next_token: str = SchemaField(description="Token for next page of results")
|
||||
# data: list[dict] = SchemaField(description="Complete follower data")
|
||||
# included: dict = SchemaField(description="Additional data requested via expansions")
|
||||
# meta: dict = SchemaField(description="Metadata about the response")
|
||||
# error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="16b289b4-a62f-11ef-95d4-bb29b849eb99",
|
||||
# description="This block retrieves followers of a specified Twitter list.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterListGetFollowersBlock.Input,
|
||||
# output_schema=TwitterListGetFollowersBlock.Output,
|
||||
# test_input={
|
||||
# "list_id": "123456789",
|
||||
# "max_results": 10,
|
||||
# "pagination_token": None,
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "expansions": [],
|
||||
# "tweet_fields": [],
|
||||
# "user_fields": []
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("user_ids", ["2244994945"]),
|
||||
# ("usernames", ["testuser"]),
|
||||
# ("next_token", None),
|
||||
# ("data", {"followers": [{"id": "2244994945", "username": "testuser"}]}),
|
||||
# ("included", {}),
|
||||
# ("meta", {}),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "get_list_followers": lambda *args, **kwargs: ({
|
||||
# "followers": [{"id": "2244994945", "username": "testuser"}]
|
||||
# }, {}, {}, ["2244994945"], ["testuser"], None)
|
||||
# }
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def get_list_followers(
|
||||
# credentials: TwitterCredentials,
|
||||
# list_id: str,
|
||||
# max_results: int,
|
||||
# pagination_token: str,
|
||||
# expansions: list[UserExpansions],
|
||||
# tweet_fields: list[TweetFields],
|
||||
# user_fields: list[TweetUserFields]
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value(),
|
||||
# )
|
||||
|
||||
# params = {
|
||||
# "id": list_id,
|
||||
# "max_results": max_results,
|
||||
# "pagination_token": None if pagination_token == "" else pagination_token,
|
||||
# "user_auth": False
|
||||
# }
|
||||
|
||||
# params = (UserExpansionsBuilder(params)
|
||||
# .add_expansions(expansions)
|
||||
# .add_tweet_fields(tweet_fields)
|
||||
# .add_user_fields(user_fields)
|
||||
# .build())
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.get_list_followers(**params)
|
||||
# )
|
||||
|
||||
# meta = {}
|
||||
# user_ids = []
|
||||
# usernames = []
|
||||
# next_token = None
|
||||
|
||||
# if response.meta:
|
||||
# meta = response.meta
|
||||
# next_token = meta.get("next_token")
|
||||
|
||||
# included = IncludesSerializer.serialize(response.includes)
|
||||
# data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
# if response.data:
|
||||
# user_ids = [str(item.id) for item in response.data]
|
||||
# usernames = [item.username for item in response.data]
|
||||
|
||||
# return data, included, meta, user_ids, usernames, next_token
|
||||
|
||||
# raise Exception("No followers found")
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# followers_data, included, meta, user_ids, usernames, next_token = self.get_list_followers(
|
||||
# credentials,
|
||||
# input_data.list_id,
|
||||
# input_data.max_results,
|
||||
# input_data.pagination_token,
|
||||
# input_data.expansions,
|
||||
# input_data.tweet_fields,
|
||||
# input_data.user_fields
|
||||
# )
|
||||
|
||||
# if user_ids:
|
||||
# yield "user_ids", user_ids
|
||||
# if usernames:
|
||||
# yield "usernames", usernames
|
||||
# if next_token:
|
||||
# yield "next_token", next_token
|
||||
# if followers_data:
|
||||
# yield "data", followers_data
|
||||
# if included:
|
||||
# yield "included", included
|
||||
# if meta:
|
||||
# yield "meta", meta
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
||||
|
||||
# class TwitterGetFollowedListsBlock(Block):
|
||||
# """
|
||||
# Gets lists followed by a specified Twitter user
|
||||
# """
|
||||
|
||||
# class Input(UserExpansionInputs):
|
||||
# credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
# ["follows.read", "users.read", "list.read", "offline.access"]
|
||||
# )
|
||||
|
||||
# user_id: str = SchemaField(
|
||||
# description="The user ID whose followed Lists to retrieve",
|
||||
# placeholder="Enter user ID",
|
||||
# required=True
|
||||
# )
|
||||
|
||||
# max_results: int = SchemaField(
|
||||
# description="Max number of results per page (1-100)",
|
||||
# placeholder="Enter max results",
|
||||
# default=10,
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# pagination_token: str = SchemaField(
|
||||
# description="Token for pagination",
|
||||
# placeholder="Enter pagination token",
|
||||
# default="",
|
||||
# advanced=True,
|
||||
# )
|
||||
|
||||
# class Output(BlockSchema):
|
||||
# list_ids: list[str] = SchemaField(description="List of list IDs")
|
||||
# list_names: list[str] = SchemaField(description="List of list names")
|
||||
# data: list[dict] = SchemaField(description="Complete list data")
|
||||
# includes: dict = SchemaField(description="Additional data requested via expansions")
|
||||
# meta: dict = SchemaField(description="Metadata about the response")
|
||||
# next_token: str = SchemaField(description="Token for next page of results")
|
||||
# error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
# def __init__(self):
|
||||
# super().__init__(
|
||||
# id="0e18bbfc-a62f-11ef-94fa-1f1e174b809e",
|
||||
# description="This block retrieves all Lists a specified user follows.",
|
||||
# categories={BlockCategory.SOCIAL},
|
||||
# input_schema=TwitterGetFollowedListsBlock.Input,
|
||||
# output_schema=TwitterGetFollowedListsBlock.Output,
|
||||
# test_input={
|
||||
# "user_id": "123456789",
|
||||
# "max_results": 10,
|
||||
# "pagination_token": None,
|
||||
# "credentials": TEST_CREDENTIALS_INPUT,
|
||||
# "expansions": [],
|
||||
# "tweet_fields": [],
|
||||
# "user_fields": []
|
||||
# },
|
||||
# test_credentials=TEST_CREDENTIALS,
|
||||
# test_output=[
|
||||
# ("list_ids", ["12345"]),
|
||||
# ("list_names", ["Test List"]),
|
||||
# ("data", {"followed_lists": [{"id": "12345", "name": "Test List"}]}),
|
||||
# ("includes", {}),
|
||||
# ("meta", {}),
|
||||
# ("next_token", None),
|
||||
# ("error", "")
|
||||
# ],
|
||||
# test_mock={
|
||||
# "get_followed_lists": lambda *args, **kwargs: ({
|
||||
# "followed_lists": [{"id": "12345", "name": "Test List"}]
|
||||
# }, {}, {}, ["12345"], ["Test List"], None)
|
||||
# }
|
||||
# )
|
||||
|
||||
# @staticmethod
|
||||
# def get_followed_lists(
|
||||
# credentials: TwitterCredentials,
|
||||
# user_id: str,
|
||||
# max_results: int,
|
||||
# pagination_token: str,
|
||||
# expansions: list[UserExpansions],
|
||||
# tweet_fields: list[TweetFields],
|
||||
# user_fields: list[TweetUserFields]
|
||||
# ):
|
||||
# try:
|
||||
# client = tweepy.Client(
|
||||
# bearer_token=credentials.access_token.get_secret_value(),
|
||||
# )
|
||||
|
||||
# params = {
|
||||
# "id": user_id,
|
||||
# "max_results": max_results,
|
||||
# "pagination_token": None if pagination_token == "" else pagination_token,
|
||||
# "user_auth": False
|
||||
# }
|
||||
|
||||
# params = (UserExpansionsBuilder(params)
|
||||
# .add_expansions(expansions)
|
||||
# .add_tweet_fields(tweet_fields)
|
||||
# .add_user_fields(user_fields)
|
||||
# .build())
|
||||
|
||||
# response = cast(
|
||||
# Response,
|
||||
# client.get_followed_lists(**params)
|
||||
# )
|
||||
|
||||
# meta = {}
|
||||
# list_ids = []
|
||||
# list_names = []
|
||||
# next_token = None
|
||||
|
||||
# if response.meta:
|
||||
# meta = response.meta
|
||||
# next_token = meta.get("next_token")
|
||||
|
||||
# included = IncludesSerializer.serialize(response.includes)
|
||||
# data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
# if response.data:
|
||||
# list_ids = [str(item.id) for item in response.data]
|
||||
# list_names = [item.name for item in response.data]
|
||||
|
||||
# return data, included, meta, list_ids, list_names, next_token
|
||||
|
||||
# raise Exception("No followed lists found")
|
||||
|
||||
# except tweepy.TweepyException:
|
||||
# raise
|
||||
|
||||
# def run(
|
||||
# self,
|
||||
# input_data: Input,
|
||||
# *,
|
||||
# credentials: TwitterCredentials,
|
||||
# **kwargs,
|
||||
# ) -> BlockOutput:
|
||||
# try:
|
||||
# lists_data, included, meta, list_ids, list_names, next_token = self.get_followed_lists(
|
||||
# credentials,
|
||||
# input_data.user_id,
|
||||
# input_data.max_results,
|
||||
# input_data.pagination_token,
|
||||
# input_data.expansions,
|
||||
# input_data.tweet_fields,
|
||||
# input_data.user_fields
|
||||
# )
|
||||
|
||||
# if list_ids:
|
||||
# yield "list_ids", list_ids
|
||||
# if list_names:
|
||||
# yield "list_names", list_names
|
||||
# if next_token:
|
||||
# yield "next_token", next_token
|
||||
# if lists_data:
|
||||
# yield "data", lists_data
|
||||
# if included:
|
||||
# yield "includes", included
|
||||
# if meta:
|
||||
# yield "meta", meta
|
||||
|
||||
# except Exception as e:
|
||||
# yield "error", handle_tweepy_exception(e)
|
||||
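
# A minimal, self-contained sketch of the output protocol the blocks above
# implement: run() is a generator of (output_name, value) pairs, so "success"
# and "error" are simply alternative yields rather than return values.
# StubUnfollowBlock and fake_unfollow are hypothetical stand-ins for
# TwitterUnfollowListBlock and its tweepy call, not part of this diff.
def fake_unfollow(list_id: str) -> bool:
    if not list_id:
        raise ValueError("list_id must be non-empty")
    return True


class StubUnfollowBlock:
    def run(self, list_id: str):
        try:
            yield "success", fake_unfollow(list_id)  # happy path: one named output
        except Exception as e:
            yield "error", str(e)  # failure path: same generator, different name


assert list(StubUnfollowBlock().run("123456789")) == [("success", True)]
assert list(StubUnfollowBlock().run("")) == [("error", "list_id must be non-empty")]
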
@@ -1,348 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import ListExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ListExpansionInputs,
    ListExpansionsFilter,
    ListFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetListBlock(Block):
    """
    Gets information about a Twitter List specified by ID
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to lookup",
            placeholder="Enter list ID",
            required=True,
        )

    class Output(BlockSchema):
        # Common outputs
        id: str = SchemaField(description="ID of the Twitter List")
        name: str = SchemaField(description="Name of the Twitter List")
        owner_id: str = SchemaField(description="ID of the List owner")
        owner_username: str = SchemaField(description="Username of the List owner")

        # Complete outputs
        data: dict = SchemaField(description="Complete list data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about the response")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="34ebc80a-a62f-11ef-9c2a-3fcab6c07079",
            description="This block retrieves information about a specified Twitter List.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetListBlock.Input,
            output_schema=TwitterGetListBlock.Output,
            test_input={
                "list_id": "84839422",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", "84839422"),
                ("name", "Official Twitter Accounts"),
                ("owner_id", "2244994945"),
                ("owner_username", "TwitterAPI"),
                ("data", {"id": "84839422", "name": "Official Twitter Accounts"}),
            ],
            test_mock={
                "get_list": lambda *args, **kwargs: (
                    {"id": "84839422", "name": "Official Twitter Accounts"},
                    {},
                    {},
                    "2244994945",
                    "TwitterAPI",
                )
            },
        )

    @staticmethod
    def get_list(
        credentials: TwitterCredentials,
        list_id: str,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {"id": list_id, "user_auth": False}

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_list(**params))

            meta = {}
            owner_id = ""
            owner_username = ""
            included = {}

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

                if "users" in included:
                    owner_id = str(included["users"][0]["id"])
                    owner_username = included["users"][0]["username"]

            if response.meta:
                meta = response.meta

            if response.data:
                data_dict = ResponseDataSerializer.serialize_dict(response.data)
                return data_dict, included, meta, owner_id, owner_username

            raise Exception("List not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, owner_id, owner_username = self.get_list(
                credentials,
                input_data.list_id,
                input_data.expansions,
                input_data.user_fields,
                input_data.list_fields,
            )

            yield "id", str(list_data["id"])
            yield "name", list_data["name"]
            if owner_id:
                yield "owner_id", owner_id
            if owner_username:
                yield "owner_username", owner_username
            yield "data", {"id": list_data["id"], "name": list_data["name"]}
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetOwnedListsBlock(Block):
    """
    Gets all Lists owned by the specified user
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "list.read", "offline.access"]
        )

        user_id: str = SchemaField(
            description="The user ID whose owned Lists to retrieve",
            placeholder="Enter user ID",
            required=True,
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            placeholder="Enter max results (default 100)",
            advanced=True,
            default=10,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination",
            placeholder="Enter pagination token",
            advanced=True,
            default="",
        )

    class Output(BlockSchema):
        # Common outputs
        list_ids: list[str] = SchemaField(description="List ids of the owned lists")
        list_names: list[str] = SchemaField(description="List names of the owned lists")
        next_token: str = SchemaField(description="Token for next page of results")

        # Complete outputs
        data: list[dict] = SchemaField(description="Complete owned lists data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about the response")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="2b6bdb26-a62f-11ef-a9ce-ff89c2568726",
            description="This block retrieves all Lists owned by a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetOwnedListsBlock.Input,
            output_schema=TwitterGetOwnedListsBlock.Output,
            test_input={
                "user_id": "2244994945",
                "max_results": 10,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("list_ids", ["84839422"]),
                ("list_names", ["Official Twitter Accounts"]),
                ("data", [{"id": "84839422", "name": "Official Twitter Accounts"}]),
            ],
            test_mock={
                "get_owned_lists": lambda *args, **kwargs: (
                    [{"id": "84839422", "name": "Official Twitter Accounts"}],
                    {},
                    {},
                    ["84839422"],
                    ["Official Twitter Accounts"],
                    None,
                )
            },
        )

    @staticmethod
    def get_owned_lists(
        credentials: TwitterCredentials,
        user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_owned_lists(**params))

            meta = {}
            included = {}
            list_ids = []
            list_names = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                list_ids = [
                    str(item.id) for item in response.data if hasattr(item, "id")
                ]
                list_names = [
                    item.name for item in response.data if hasattr(item, "name")
                ]

                return data, included, meta, list_ids, list_names, next_token

            raise Exception("User has no owned Lists")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, list_ids, list_names, next_token = (
                self.get_owned_lists(
                    credentials,
                    input_data.user_id,
                    input_data.max_results,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.user_fields,
                    input_data.list_fields,
                )
            )

            if list_ids:
                yield "list_ids", list_ids
            if list_names:
                yield "list_names", list_names
            if next_token:
                yield "next_token", next_token
            if list_data:
                yield "data", list_data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
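
# A hedged, self-contained sketch of the next_token pagination contract the
# helpers above rely on: each call returns a page plus an optional token, and
# the token is absent on the last page. fetch_page is a hypothetical stand-in
# for a helper such as get_owned_lists, not part of this diff.
def fetch_page(token: str | None) -> tuple[list[str], str | None]:
    pages = {None: (["84839422"], "tok-1"), "tok-1": (["84839423"], None)}
    return pages[token]


def fetch_all_ids() -> list[str]:
    ids: list[str] = []
    token: str | None = None
    while True:
        page, token = fetch_page(token)
        ids.extend(page)
        if not token:  # no next_token means we are on the final page
            return ids


assert fetch_all_ids() == ["84839422", "84839423"]
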
@@ -1,527 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
    ListExpansionsBuilder,
    UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ListExpansionInputs,
    ListExpansionsFilter,
    ListFieldsFilter,
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterRemoveListMemberBlock(Block):
    """
    Removes a member from a Twitter List that the authenticated user owns
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "users.read", "tweet.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to remove the member from",
            placeholder="Enter list ID",
            required=True,
        )

        user_id: str = SchemaField(
            description="The ID of the user to remove from the List",
            placeholder="Enter user ID to remove",
            required=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the member was successfully removed"
        )
        error: str = SchemaField(description="Error message if the removal failed")

    def __init__(self):
        super().__init__(
            id="5a3d1320-a62f-11ef-b7ce-a79e7656bcb0",
            description="This block removes a specified user from a Twitter List owned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterRemoveListMemberBlock.Input,
            output_schema=TwitterRemoveListMemberBlock.Output,
            test_input={
                "list_id": "123456789",
                "user_id": "987654321",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"remove_list_member": lambda *args, **kwargs: True},
        )

    @staticmethod
    def remove_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )
            client.remove_list_member(id=list_id, user_id=user_id, user_auth=False)
            return True
        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.remove_list_member(
                credentials, input_data.list_id, input_data.user_id
            )
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterAddListMemberBlock(Block):
    """
    Adds a member to a Twitter List that the authenticated user owns
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "users.read", "tweet.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to add the member to",
            placeholder="Enter list ID",
            required=True,
        )

        user_id: str = SchemaField(
            description="The ID of the user to add to the List",
            placeholder="Enter user ID to add",
            required=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the member was successfully added"
        )
        error: str = SchemaField(description="Error message if the addition failed")

    def __init__(self):
        super().__init__(
            id="3ee8284e-a62f-11ef-84e4-8f6e2cbf0ddb",
            description="This block adds a specified user to a Twitter List owned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterAddListMemberBlock.Input,
            output_schema=TwitterAddListMemberBlock.Output,
            test_input={
                "list_id": "123456789",
                "user_id": "987654321",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"add_list_member": lambda *args, **kwargs: True},
        )

    @staticmethod
    def add_list_member(credentials: TwitterCredentials, list_id: str, user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )
            client.add_list_member(id=list_id, user_id=user_id, user_auth=False)
            return True
        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.add_list_member(
                credentials, input_data.list_id, input_data.user_id
            )
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetListMembersBlock(Block):
    """
    Gets the members of a specified Twitter List
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to get members from",
            placeholder="Enter list ID",
            required=True,
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        ids: list[str] = SchemaField(description="List of member user IDs")
        usernames: list[str] = SchemaField(description="List of member usernames")
        next_token: str = SchemaField(description="Next token for pagination")

        data: list[dict] = SchemaField(
            description="Complete user data for list members"
        )
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="4dba046e-a62f-11ef-b69a-87240c84b4c7",
            description="This block retrieves the members of a specified Twitter List.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetListMembersBlock.Input,
            output_schema=TwitterGetListMembersBlock.Output,
            test_input={
                "list_id": "123456789",
                "max_results": 2,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["12345", "67890"]),
                ("usernames", ["testuser1", "testuser2"]),
                (
                    "data",
                    [
                        {"id": "12345", "username": "testuser1"},
                        {"id": "67890", "username": "testuser2"},
                    ],
                ),
            ],
            test_mock={
                "get_list_members": lambda *args, **kwargs: (
                    ["12345", "67890"],
                    ["testuser1", "testuser2"],
                    [
                        {"id": "12345", "username": "testuser1"},
                        {"id": "67890", "username": "testuser2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_list_members(
        credentials: TwitterCredentials,
        list_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": list_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_list_members(**params))

            meta = {}
            included = {}
            next_token = None
            user_ids = []
            usernames = []

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                user_ids = [str(user.id) for user in response.data]
                usernames = [user.username for user in response.data]
                return user_ids, usernames, data, included, meta, next_token

            raise Exception("List members not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, usernames, data, included, meta, next_token = self.get_list_members(
                credentials,
                input_data.list_id,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )

            if ids:
                yield "ids", ids
            if usernames:
                yield "usernames", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetListMembershipsBlock(Block):
    """
    Gets all Lists that a specified user is a member of
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.read", "offline.access"]
        )

        user_id: str = SchemaField(
            description="The ID of the user whose List memberships to retrieve",
            placeholder="Enter user ID",
            required=True,
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            placeholder="Enter max results",
            advanced=True,
            default=10,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination of results",
            placeholder="Enter pagination token",
            advanced=True,
            default="",
        )

    class Output(BlockSchema):
        list_ids: list[str] = SchemaField(description="List of list IDs")
        next_token: str = SchemaField(description="Next token for pagination")

        data: list[dict] = SchemaField(description="List membership data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about pagination")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="46e6429c-a62f-11ef-81c0-2b55bc7823ba",
            description="This block retrieves all Lists that a specified user is a member of.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetListMembershipsBlock.Input,
            output_schema=TwitterGetListMembershipsBlock.Output,
            test_input={
                "user_id": "123456789",
                "max_results": 1,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("list_ids", ["84839422"]),
                ("data", [{"id": "84839422"}]),
            ],
            test_mock={
                "get_list_memberships": lambda *args, **kwargs: (
                    [{"id": "84839422"}],
                    {},
                    {},
                    ["84839422"],
                    None,
                )
            },
        )

    @staticmethod
    def get_list_memberships(
        credentials: TwitterCredentials,
        user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_list_memberships(**params))

            meta = {}
            included = {}
            next_token = None
            list_ids = []

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                list_ids = [str(lst.id) for lst in response.data]
                return data, included, meta, list_ids, next_token

            raise Exception("List memberships not found")

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            data, included, meta, list_ids, next_token = self.get_list_memberships(
                credentials,
                input_data.user_id,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.user_fields,
                input_data.list_fields,
            )

            if list_ids:
                yield "list_ids", list_ids
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
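
# The ListExpansionsBuilder/UserExpansionsBuilder chains above follow a fluent
# "add each filter only if it was provided, then build" pattern. This is a
# self-contained approximation of that pattern; the real builders live in
# backend.blocks.twitter._builders and may differ in detail.
class ParamsBuilder:
    def __init__(self, base: dict):
        self.params = dict(base)

    def add(self, key: str, values: list[str] | None) -> "ParamsBuilder":
        if values:  # None or empty filters are omitted so the request stays minimal
            self.params[key] = ",".join(values)
        return self

    def build(self) -> dict:
        return self.params


params = (
    ParamsBuilder({"id": "84839422", "user_auth": False})
    .add("expansions", ["owner_id"])
    .add("user.fields", None)  # dropped entirely, like a None filter above
    .build()
)
assert params == {"id": "84839422", "user_auth": False, "expansions": "owner_id"}
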
@@ -1,217 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetListTweetsBlock(Block):
    """
    Gets tweets from a specified Twitter list
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List whose Tweets you would like to retrieve",
            placeholder="Enter list ID",
            required=True,
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for paginating through results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        # Common outputs
        tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
        texts: list[str] = SchemaField(description="List of tweet texts")
        next_token: str = SchemaField(description="Token for next page of results")

        # Complete outputs
        data: list[dict] = SchemaField(description="Complete list tweets data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(
            description="Response metadata including pagination tokens"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="6657edb0-a62f-11ef-8c10-0326d832467d",
            description="This block retrieves tweets from a specified Twitter list.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetListTweetsBlock.Input,
            output_schema=TwitterGetListTweetsBlock.Output,
            test_input={
                "list_id": "84839422",
                "max_results": 1,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("tweet_ids", ["1234567890"]),
                ("texts", ["Test tweet"]),
                ("data", [{"id": "1234567890", "text": "Test tweet"}]),
            ],
            test_mock={
                "get_list_tweets": lambda *args, **kwargs: (
                    [{"id": "1234567890", "text": "Test tweet"}],
                    {},
                    {},
                    ["1234567890"],
                    ["Test tweet"],
                    None,
                )
            },
        )

    @staticmethod
    def get_list_tweets(
        credentials: TwitterCredentials,
        list_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": list_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_list_tweets(**params))

            meta = {}
            included = {}
            tweet_ids = []
            texts = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                tweet_ids = [str(item.id) for item in response.data]
                texts = [item.text for item in response.data]

                return data, included, meta, tweet_ids, texts, next_token

            raise Exception("No tweets found in this list")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, tweet_ids, texts, next_token = (
                self.get_list_tweets(
                    credentials,
                    input_data.list_id,
                    input_data.max_results,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.media_fields,
                    input_data.place_fields,
                    input_data.poll_fields,
                    input_data.tweet_fields,
                    input_data.user_fields,
                )
            )

            if tweet_ids:
                yield "tweet_ids", tweet_ids
            if texts:
                yield "texts", texts
            if next_token:
                yield "next_token", next_token
            if list_data:
                yield "data", list_data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
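
# A rough sketch of why the test_mock entries above are keyed by method name:
# a test harness can swap the named network helper for a lambda that returns
# the same tuple shape, leaving run() untouched. The patching mechanism shown
# here (a plain setattr on an instance) is an assumption about the harness,
# and StubListTweetsBlock is hypothetical.
class StubListTweetsBlock:
    @staticmethod
    def get_list_tweets(*args, **kwargs):
        raise RuntimeError("would hit the network")

    def run(self, list_id: str):
        data, ids = self.get_list_tweets(list_id)
        yield "tweet_ids", ids
        yield "data", data


block = StubListTweetsBlock()
setattr(block, "get_list_tweets", lambda *a, **k: ([{"id": "1"}], ["1"]))
assert list(block.run("84839422")) == [("tweet_ids", ["1"]), ("data", [{"id": "1"}])]
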
@@ -1,278 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterDeleteListBlock(Block):
    """
    Deletes a Twitter List owned by the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to be deleted",
            placeholder="Enter list ID",
            required=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the deletion was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="843c6892-a62f-11ef-a5c8-b71239a78d3b",
            description="This block deletes a specified Twitter List owned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterDeleteListBlock.Input,
            output_schema=TwitterDeleteListBlock.Output,
            test_input={"list_id": "1234567890", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"delete_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def delete_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.delete_list(id=list_id, user_auth=False)
            return True

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.delete_list(credentials, input_data.list_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterUpdateListBlock(Block):
    """
    Updates a Twitter List owned by the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to be updated",
            placeholder="Enter list ID",
            advanced=False,
        )

        name: str | None = SchemaField(
            description="New name for the List",
            placeholder="Enter list name",
            default="",
            advanced=False,
        )

        description: str | None = SchemaField(
            description="New description for the List",
            placeholder="Enter list description",
            default="",
            advanced=False,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the update was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="7d12630a-a62f-11ef-90c9-8f5a996612c3",
            description="This block updates a specified Twitter List owned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUpdateListBlock.Input,
            output_schema=TwitterUpdateListBlock.Output,
            test_input={
                "list_id": "1234567890",
                "name": "Updated List Name",
                "description": "Updated List Description",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"update_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def update_list(
        credentials: TwitterCredentials,
        list_id: str,
        name: str | None,
        description: str | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.update_list(
                id=list_id,
                name=None if name == "" else name,
                description=None if description == "" else description,
                user_auth=False,
            )
            return True

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.update_list(
                credentials, input_data.list_id, input_data.name, input_data.description
            )
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterCreateListBlock(Block):
    """
    Creates a Twitter List owned by the authenticated user
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "offline.access"]
        )

        name: str = SchemaField(
            description="The name of the List to be created",
            placeholder="Enter list name",
            advanced=False,
            default="",
        )

        description: str | None = SchemaField(
            description="Description of the List",
            placeholder="Enter list description",
            advanced=False,
            default="",
        )

        private: bool = SchemaField(
            description="Whether the List should be private",
            advanced=False,
            default=False,
        )

    class Output(BlockSchema):
        url: str = SchemaField(description="URL of the created list")
        list_id: str = SchemaField(description="ID of the created list")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="724148ba-a62f-11ef-89ba-5349b813ef5f",
            description="This block creates a new Twitter List for the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterCreateListBlock.Input,
            output_schema=TwitterCreateListBlock.Output,
            test_input={
                "name": "New List Name",
                "description": "New List Description",
                "private": True,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("list_id", "1234567890"),
                ("url", "https://twitter.com/i/lists/1234567890"),
            ],
            test_mock={"create_list": lambda *args, **kwargs: "1234567890"},
        )

    @staticmethod
    def create_list(
        credentials: TwitterCredentials,
        name: str,
        description: str | None,
        private: bool,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            response = cast(
                Response,
                client.create_list(
                    name=None if name == "" else name,
                    description=None if description == "" else description,
                    private=private,
                    user_auth=False,
                ),
            )

            list_id = str(response.data["id"])

            return list_id

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_id = self.create_list(
                credentials, input_data.name, input_data.description, input_data.private
            )
            yield "list_id", list_id
            yield "url", f"https://twitter.com/i/lists/{list_id}"

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
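
# Every helper above builds its client the same way: the stored OAuth 2.0 user
# access token goes into tweepy.Client's bearer_token slot, and user_auth=False
# is then passed per call so tweepy signs with that token rather than OAuth 1.0a
# keys. A hedged sketch of that construction; the token value here is fake, and
# instantiating the client makes no network request.
import tweepy


def make_client(oauth2_user_token: str) -> tweepy.Client:
    return tweepy.Client(bearer_token=oauth2_user_token)


client = make_client("FAKE-OAUTH2-USER-TOKEN")
# e.g. client.create_list(name="My List", private=True, user_auth=False)
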
@@ -1,285 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
import tweepy
|
||||
from tweepy.client import Response
|
||||
|
||||
from backend.blocks.twitter._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
TwitterCredentials,
|
||||
TwitterCredentialsField,
|
||||
TwitterCredentialsInput,
|
||||
)
|
||||
from backend.blocks.twitter._builders import ListExpansionsBuilder
|
||||
from backend.blocks.twitter._serializer import (
|
||||
IncludesSerializer,
|
||||
ResponseDataSerializer,
|
||||
)
|
||||
from backend.blocks.twitter._types import (
|
||||
ListExpansionInputs,
|
||||
ListExpansionsFilter,
|
||||
ListFieldsFilter,
|
||||
TweetUserFieldsFilter,
|
||||
)
|
||||
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
|
||||
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
|
||||
from backend.data.model import SchemaField
|
||||
|
||||
|
||||
class TwitterUnpinListBlock(Block):
|
||||
"""
|
||||
Enables the authenticated user to unpin a List.
|
||||
"""
|
||||
|
||||
class Input(BlockSchema):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["list.write", "users.read", "tweet.read", "offline.access"]
|
||||
)
|
||||
|
||||
list_id: str = SchemaField(
|
||||
description="The ID of the List to unpin",
|
||||
placeholder="Enter list ID",
|
||||
required=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
success: bool = SchemaField(description="Whether the unpin was successful")
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="a099c034-a62f-11ef-9622-47d0ceb73555",
|
||||
description="This block allows the authenticated user to unpin a specified List.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterUnpinListBlock.Input,
|
||||
output_schema=TwitterUnpinListBlock.Output,
|
||||
test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("success", True)],
|
||||
test_mock={"unpin_list": lambda *args, **kwargs: True},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def unpin_list(credentials: TwitterCredentials, list_id: str):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
client.unpin_list(list_id=list_id, user_auth=False)
|
||||
|
||||
return True
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
success = self.unpin_list(credentials, input_data.list_id)
|
||||
yield "success", success
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterPinListBlock(Block):
    """
    Enables the authenticated user to pin a List.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["list.write", "users.read", "tweet.read", "offline.access"]
        )

        list_id: str = SchemaField(
            description="The ID of the List to pin",
            placeholder="Enter list ID",
            required=True,
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the pin was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="8ec16e48-a62f-11ef-9f35-f3d6de43a802",
            description="This block allows the authenticated user to pin a specified List.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterPinListBlock.Input,
            output_schema=TwitterPinListBlock.Output,
            test_input={"list_id": "123456789", "credentials": TEST_CREDENTIALS_INPUT},
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"pin_list": lambda *args, **kwargs: True},
        )

    @staticmethod
    def pin_list(credentials: TwitterCredentials, list_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.pin_list(list_id=list_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.pin_list(credentials, input_data.list_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

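# Illustrative sketch (hypothetical helper, not part of the original block set):
# shows how the pin/unpin pair above can be driven directly. Assumes `creds` is
# a valid TwitterCredentials instance whose token carries the list.write scope.
def _example_toggle_pin(creds: TwitterCredentials, list_id: str) -> bool:
    """Pin a List and then unpin it again; True only if both calls succeed."""
    pinned = TwitterPinListBlock.pin_list(creds, list_id)
    unpinned = TwitterUnpinListBlock.unpin_list(creds, list_id)
    return pinned and unpinned
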
class TwitterGetPinnedListsBlock(Block):
    """
    Returns the Lists pinned by the authenticated user.
    """

    class Input(ListExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["lists.read", "users.read", "offline.access"]
        )

    class Output(BlockSchema):
        list_ids: list[str] = SchemaField(description="List IDs of the pinned lists")
        list_names: list[str] = SchemaField(
            description="List names of the pinned lists"
        )

        data: list[dict] = SchemaField(
            description="Response data containing pinned lists"
        )
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata about the response")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="97e03aae-a62f-11ef-bc53-5b89cb02888f",
            description="This block returns the Lists pinned by the authenticated user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetPinnedListsBlock.Input,
            output_schema=TwitterGetPinnedListsBlock.Output,
            test_input={
                "expansions": None,
                "list_fields": None,
                "user_fields": None,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("list_ids", ["84839422"]),
                ("list_names", ["Twitter List"]),
                ("data", [{"id": "84839422", "name": "Twitter List"}]),
            ],
            test_mock={
                "get_pinned_lists": lambda *args, **kwargs: (
                    [{"id": "84839422", "name": "Twitter List"}],
                    {},
                    {},
                    ["84839422"],
                    ["Twitter List"],
                )
            },
        )

    @staticmethod
    def get_pinned_lists(
        credentials: TwitterCredentials,
        expansions: ListExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
        list_fields: ListFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {"user_auth": False}

            params = (
                ListExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .add_list_fields(list_fields)
                .build()
            )

            response = cast(Response, client.get_pinned_lists(**params))

            meta = {}
            included = {}
            list_ids = []
            list_names = []

            if response.meta:
                meta = response.meta

            if response.includes:
                included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                list_ids = [str(item.id) for item in response.data]
                list_names = [item.name for item in response.data]
                return data, included, meta, list_ids, list_names

            raise Exception("Lists not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            list_data, included, meta, list_ids, list_names = self.get_pinned_lists(
                credentials,
                input_data.expansions,
                input_data.user_fields,
                input_data.list_fields,
            )

            if list_ids:
                yield "list_ids", list_ids
            if list_names:
                yield "list_names", list_names
            if list_data:
                yield "data", list_data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

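# Illustrative sketch (hypothetical helper): fetch the authenticated user's
# pinned Lists and return them as a {list_id: list_name} mapping. Assumes
# `creds` is a valid TwitterCredentials instance with the lists.read scope.
def _example_pinned_lists_by_id(creds: TwitterCredentials) -> dict[str, str]:
    _, _, _, list_ids, list_names = TwitterGetPinnedListsBlock.get_pinned_lists(
        creds, expansions=None, user_fields=None, list_fields=None
    )
    return dict(zip(list_ids, list_names))
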
@@ -1,195 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import SpaceExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    SpaceExpansionInputs,
    SpaceExpansionsFilter,
    SpaceFieldsFilter,
    SpaceStatesFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterSearchSpacesBlock(Block):
    """
    Returns live or scheduled Spaces matching the specified search terms
    (results cover roughly one week only).
    """

    class Input(SpaceExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["spaces.read", "users.read", "tweet.read", "offline.access"]
        )

        query: str = SchemaField(
            description="Search term to find in Space titles",
            placeholder="Enter search query",
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        state: SpaceStatesFilter = SchemaField(
            description="Type of Spaces to return (live, scheduled, or all)",
            placeholder="Enter state filter",
            default=SpaceStatesFilter.all,
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list[str] = SchemaField(description="List of space IDs")
        titles: list[str] = SchemaField(description="List of space titles")
        host_ids: list = SchemaField(description="List of host IDs")
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete space data")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="aaefdd48-a62f-11ef-a73c-3f44df63e276",
            description="This block searches for Twitter Spaces based on specified terms.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterSearchSpacesBlock.Input,
            output_schema=TwitterSearchSpacesBlock.Output,
            test_input={
                "query": "tech",
                "max_results": 1,
                "state": "live",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "space_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1234"]),
                ("titles", ["Tech Talk"]),
                ("host_ids", ["5678"]),
                ("data", [{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}]),
            ],
            test_mock={
                "search_spaces": lambda *args, **kwargs: (
                    [{"id": "1234", "title": "Tech Talk", "host_ids": ["5678"]}],
                    {},
                    {},
                    ["1234"],
                    ["Tech Talk"],
                    ["5678"],
                    None,
                )
            },
        )

    @staticmethod
    def search_spaces(
        credentials: TwitterCredentials,
        query: str,
        max_results: int | None,
        state: SpaceStatesFilter,
        expansions: SpaceExpansionsFilter | None,
        space_fields: SpaceFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {"query": query, "max_results": max_results, "state": state.value}

            params = (
                SpaceExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_space_fields(space_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.search_spaces(**params))

            meta = {}
            next_token = ""
            if response.meta:
                meta = response.meta
                if "next_token" in meta:
                    next_token = meta["next_token"]

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                # Read all three fields from the serialized dicts for consistency
                # (the original mixed `response.data` and `data` here).
                ids = [str(space["id"]) for space in data if "id" in space]
                titles = [space["title"] for space in data if "title" in space]
                host_ids = [space["host_ids"] for space in data if "host_ids" in space]

                return data, included, meta, ids, titles, host_ids, next_token

            raise Exception("Spaces not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            data, included, meta, ids, titles, host_ids, next_token = (
                self.search_spaces(
                    credentials,
                    input_data.query,
                    input_data.max_results,
                    input_data.state,
                    input_data.expansions,
                    input_data.space_fields,
                    input_data.user_fields,
                )
            )

            if ids:
                yield "ids", ids
            if titles:
                yield "titles", titles
            if host_ids:
                yield "host_ids", host_ids
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "includes", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

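# Illustrative sketch (hypothetical helper): search for live Spaces about a
# topic and return just their titles. Assumes `creds` is a valid
# TwitterCredentials instance, and assumes SpaceStatesFilter exposes a `live`
# member (the test input above passes the string "live").
def _example_live_space_titles(creds: TwitterCredentials, topic: str) -> list[str]:
    _, _, _, _, titles, _, _ = TwitterSearchSpacesBlock.search_spaces(
        creds,
        query=topic,
        max_results=10,
        state=SpaceStatesFilter.live,
        expansions=None,
        space_fields=None,
        user_fields=None,
    )
    return titles
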
@@ -1,651 +0,0 @@
from typing import Literal, Union, cast

import tweepy
from pydantic import BaseModel
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
    SpaceExpansionsBuilder,
    TweetExpansionsBuilder,
    UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    SpaceExpansionInputs,
    SpaceExpansionsFilter,
    SpaceFieldsFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class SpaceList(BaseModel):
    discriminator: Literal["space_list"]
    space_ids: list[str] = SchemaField(
        description="List of Space IDs to lookup (up to 100)",
        placeholder="Enter Space IDs",
        default=[],
        advanced=False,
    )


class UserList(BaseModel):
    discriminator: Literal["user_list"]
    user_ids: list[str] = SchemaField(
        description="List of user IDs to lookup their Spaces (up to 100)",
        placeholder="Enter user IDs",
        default=[],
        advanced=False,
    )


class TwitterGetSpacesBlock(Block):
    """
    Gets information about multiple Twitter Spaces specified by Space IDs or creator user IDs
    """

    class Input(SpaceExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["spaces.read", "users.read", "offline.access"]
        )

        identifier: Union[SpaceList, UserList] = SchemaField(
            discriminator="discriminator",
            description="Choose whether to lookup spaces by their IDs or by creator user IDs",
            advanced=False,
        )

    class Output(BlockSchema):
        # Common outputs
        ids: list[str] = SchemaField(description="List of space IDs")
        titles: list[str] = SchemaField(description="List of space titles")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete space data")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="d75bd7d8-a62f-11ef-b0d8-c7a9496f617f",
            description="This block retrieves information about multiple Twitter Spaces.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetSpacesBlock.Input,
            output_schema=TwitterGetSpacesBlock.Output,
            test_input={
                "identifier": {
                    "discriminator": "space_list",
                    "space_ids": ["1DXxyRYNejbKM"],
                },
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "space_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1DXxyRYNejbKM"]),
                ("titles", ["Test Space"]),
                (
                    "data",
                    [
                        {
                            "id": "1DXxyRYNejbKM",
                            "title": "Test Space",
                            "host_id": "1234567",
                        }
                    ],
                ),
            ],
            test_mock={
                "get_spaces": lambda *args, **kwargs: (
                    [
                        {
                            "id": "1DXxyRYNejbKM",
                            "title": "Test Space",
                            "host_id": "1234567",
                        }
                    ],
                    {},
                    ["1DXxyRYNejbKM"],
                    ["Test Space"],
                )
            },
        )

    @staticmethod
    def get_spaces(
        credentials: TwitterCredentials,
        identifier: Union[SpaceList, UserList],
        expansions: SpaceExpansionsFilter | None,
        space_fields: SpaceFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "ids": (
                    identifier.space_ids if isinstance(identifier, SpaceList) else None
                ),
                "user_ids": (
                    identifier.user_ids if isinstance(identifier, UserList) else None
                ),
            }

            params = (
                SpaceExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_space_fields(space_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_spaces(**params))

            ids = []
            titles = []

            included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                ids = [space["id"] for space in data if "id" in space]
                titles = [space["title"] for space in data if "title" in space]

                return data, included, ids, titles

            raise Exception("No spaces found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            data, included, ids, titles = self.get_spaces(
                credentials,
                input_data.identifier,
                input_data.expansions,
                input_data.space_fields,
                input_data.user_fields,
            )

            if ids:
                yield "ids", ids
            if titles:
                yield "titles", titles

            if data:
                yield "data", data
            if included:
                yield "includes", included

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

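# Illustrative sketch: the `identifier` input above is a discriminated union,
# so the same block can look Spaces up either by Space IDs or by creator user
# IDs. Hypothetical helper; assumes `creds` is a valid TwitterCredentials
# instance with the spaces.read scope.
def _example_titles_for_creators(
    creds: TwitterCredentials, user_ids: list[str]
) -> list[str]:
    identifier = UserList(discriminator="user_list", user_ids=user_ids)
    _, _, _, titles = TwitterGetSpacesBlock.get_spaces(
        creds, identifier, expansions=None, space_fields=None, user_fields=None
    )
    return titles
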
class TwitterGetSpaceByIdBlock(Block):
    """
    Gets information about a single Twitter Space specified by Space ID
    """

    class Input(SpaceExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["spaces.read", "users.read", "offline.access"]
        )

        space_id: str = SchemaField(
            description="Space ID to lookup",
            placeholder="Enter Space ID",
            required=True,
        )

    class Output(BlockSchema):
        # Common outputs
        id: str = SchemaField(description="Space ID")
        title: str = SchemaField(description="Space title")
        host_ids: list[str] = SchemaField(description="Host IDs")

        # Complete outputs for advanced use
        data: dict = SchemaField(description="Complete space data")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="c79700de-a62f-11ef-ab20-fb32bf9d5a9d",
            description="This block retrieves information about a single Twitter Space.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetSpaceByIdBlock.Input,
            output_schema=TwitterGetSpaceByIdBlock.Output,
            test_input={
                "space_id": "1DXxyRYNejbKM",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "space_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", "1DXxyRYNejbKM"),
                ("title", "Test Space"),
                ("host_ids", ["1234567"]),
                (
                    "data",
                    {
                        "id": "1DXxyRYNejbKM",
                        "title": "Test Space",
                        "host_ids": ["1234567"],
                    },
                ),
            ],
            test_mock={
                "get_space": lambda *args, **kwargs: (
                    {
                        "id": "1DXxyRYNejbKM",
                        "title": "Test Space",
                        "host_ids": ["1234567"],
                    },
                    {},
                )
            },
        )

    @staticmethod
    def get_space(
        credentials: TwitterCredentials,
        space_id: str,
        expansions: SpaceExpansionsFilter | None,
        space_fields: SpaceFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": space_id,
            }

            params = (
                SpaceExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_space_fields(space_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_space(**params))

            includes = {}
            if response.includes:
                for key, value in response.includes.items():
                    if isinstance(value, list):
                        includes[key] = [
                            item.data if hasattr(item, "data") else item
                            for item in value
                        ]
                    else:
                        includes[key] = value.data if hasattr(value, "data") else value

            data = {}
            if response.data:
                for key, value in response.data.items():
                    if isinstance(value, list):
                        data[key] = [
                            item.data if hasattr(item, "data") else item
                            for item in value
                        ]
                    else:
                        data[key] = value.data if hasattr(value, "data") else value

                return data, includes

            raise Exception("Space not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            space_data, includes = self.get_space(
                credentials,
                input_data.space_id,
                input_data.expansions,
                input_data.space_fields,
                input_data.user_fields,
            )

            # Common outputs
            if space_data:
                if "id" in space_data:
                    yield "id", space_data.get("id")

                if "title" in space_data:
                    yield "title", space_data.get("title")

                if "host_ids" in space_data:
                    yield "host_ids", space_data.get("host_ids")

            if space_data:
                yield "data", space_data
            if includes:
                yield "includes", includes

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

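# Illustrative sketch (hypothetical helper): look up a single Space and return
# its title, or None when the field was not requested or returned. Assumes
# `creds` is a valid TwitterCredentials instance.
def _example_space_title(creds: TwitterCredentials, space_id: str) -> str | None:
    data, _ = TwitterGetSpaceByIdBlock.get_space(
        creds, space_id, expansions=None, space_fields=None, user_fields=None
    )
    return data.get("title")
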
# Not tested yet; may have problems
class TwitterGetSpaceBuyersBlock(Block):
    """
    Gets a list of users who purchased a ticket to the requested Space
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["spaces.read", "users.read", "offline.access"]
        )

        space_id: str = SchemaField(
            description="Space ID to lookup buyers for",
            placeholder="Enter Space ID",
            required=True,
        )

    class Output(BlockSchema):
        # Common outputs
        buyer_ids: list[str] = SchemaField(description="List of buyer IDs")
        usernames: list[str] = SchemaField(description="List of buyer usernames")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete space buyers data")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="c1c121a8-a62f-11ef-8b0e-d7b85f96a46f",
            description="This block retrieves a list of users who purchased tickets to a Twitter Space.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetSpaceBuyersBlock.Input,
            output_schema=TwitterGetSpaceBuyersBlock.Output,
            test_input={
                "space_id": "1DXxyRYNejbKM",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("buyer_ids", ["2244994945"]),
                ("usernames", ["testuser"]),
                (
                    "data",
                    [{"id": "2244994945", "username": "testuser", "name": "Test User"}],
                ),
            ],
            test_mock={
                "get_space_buyers": lambda *args, **kwargs: (
                    [{"id": "2244994945", "username": "testuser", "name": "Test User"}],
                    {},
                    ["2244994945"],
                    ["testuser"],
                )
            },
        )

    @staticmethod
    def get_space_buyers(
        credentials: TwitterCredentials,
        space_id: str,
        expansions: UserExpansionsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": space_id,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_space_buyers(**params))

            included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                buyer_ids = [buyer["id"] for buyer in data]
                usernames = [buyer["username"] for buyer in data]

                return data, included, buyer_ids, usernames

            raise Exception("No buyers found for this Space")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            buyers_data, included, buyer_ids, usernames = self.get_space_buyers(
                credentials,
                input_data.space_id,
                input_data.expansions,
                input_data.user_fields,
            )

            if buyer_ids:
                yield "buyer_ids", buyer_ids
            if usernames:
                yield "usernames", usernames

            if buyers_data:
                yield "data", buyers_data
            if included:
                yield "includes", included

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

class TwitterGetSpaceTweetsBlock(Block):
    """
    Gets a list of Tweets shared in the requested Space
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["spaces.read", "users.read", "offline.access"]
        )

        space_id: str = SchemaField(
            description="Space ID to lookup tweets for",
            placeholder="Enter Space ID",
            required=True,
        )

    class Output(BlockSchema):
        # Common outputs
        tweet_ids: list[str] = SchemaField(description="List of tweet IDs")
        texts: list[str] = SchemaField(description="List of tweet texts")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete space tweets data")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Response metadata")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="b69731e6-a62f-11ef-b2d4-1bf14dd6aee4",
            description="This block retrieves tweets shared in a Twitter Space.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetSpaceTweetsBlock.Input,
            output_schema=TwitterGetSpaceTweetsBlock.Output,
            test_input={
                "space_id": "1DXxyRYNejbKM",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("tweet_ids", ["1234567890"]),
                ("texts", ["Test tweet"]),
                ("data", [{"id": "1234567890", "text": "Test tweet"}]),
            ],
            test_mock={
                "get_space_tweets": lambda *args, **kwargs: (
                    [{"id": "1234567890", "text": "Test tweet"}],  # data
                    {},
                    ["1234567890"],
                    ["Test tweet"],
                    {},
                )
            },
        )

    @staticmethod
    def get_space_tweets(
        credentials: TwitterCredentials,
        space_id: str,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": space_id,
            }

            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_space_tweets(**params))

            included = IncludesSerializer.serialize(response.includes)

            if response.data:
                data = ResponseDataSerializer.serialize_list(response.data)
                tweet_ids = [str(tweet["id"]) for tweet in data]
                texts = [tweet["text"] for tweet in data]

                meta = response.meta or {}

                return data, included, tweet_ids, texts, meta

            raise Exception("No tweets found for this Space")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            tweets_data, included, tweet_ids, texts, meta = self.get_space_tweets(
                credentials,
                input_data.space_id,
                input_data.expansions,
                input_data.media_fields,
                input_data.place_fields,
                input_data.poll_fields,
                input_data.tweet_fields,
                input_data.user_fields,
            )

            if tweet_ids:
                yield "tweet_ids", tweet_ids
            if texts:
                yield "texts", texts

            if tweets_data:
                yield "data", tweets_data
            if included:
                yield "includes", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)

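# Illustrative sketch (hypothetical helper): collect just the tweet texts
# shared in a Space, ignoring the expansion payloads. Assumes `creds` is a
# valid TwitterCredentials instance.
def _example_space_tweet_texts(creds: TwitterCredentials, space_id: str) -> list[str]:
    _, _, _, texts, _ = TwitterGetSpaceTweetsBlock.get_space_tweets(
        creds,
        space_id,
        expansions=None,
        media_fields=None,
        place_fields=None,
        poll_fields=None,
        tweet_fields=None,
        user_fields=None,
    )
    return texts
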
@@ -1,20 +0,0 @@
import tweepy


def handle_tweepy_exception(e: Exception) -> str:
    if isinstance(e, tweepy.BadRequest):
        return f"Bad Request (400): {str(e)}"
    elif isinstance(e, tweepy.Unauthorized):
        return f"Unauthorized (401): {str(e)}"
    elif isinstance(e, tweepy.Forbidden):
        return f"Forbidden (403): {str(e)}"
    elif isinstance(e, tweepy.NotFound):
        return f"Not Found (404): {str(e)}"
    elif isinstance(e, tweepy.TooManyRequests):
        return f"Too Many Requests (429): {str(e)}"
    elif isinstance(e, tweepy.TwitterServerError):
        return f"Twitter Server Error (5xx): {str(e)}"
    elif isinstance(e, tweepy.TweepyException):
        return f"Tweepy Error: {str(e)}"
    else:
        return f"Unexpected error: {str(e)}"

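# Illustrative sketch: because every branch above returns a plain string, block
# `run` methods can yield the result directly as their "error" output. A
# minimal, hypothetical demo using the base TweepyException:
def _example_error_message() -> str:
    try:
        raise tweepy.TweepyException("rate limit budget exhausted")
    except Exception as exc:
        # -> "Tweepy Error: rate limit budget exhausted"
        return handle_tweepy_exception(exc)
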
@@ -1,372 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterBookmarkTweetBlock(Block):
    """
    Bookmark a tweet on Twitter
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "bookmark.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to bookmark",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the bookmark was successful")
        error: str = SchemaField(description="Error message if the bookmark failed")

    def __init__(self):
        super().__init__(
            id="f33d67be-a62f-11ef-a797-ff83ec29ee8e",
            description="This block bookmarks a tweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterBookmarkTweetBlock.Input,
            output_schema=TwitterBookmarkTweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"bookmark_tweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def bookmark_tweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.bookmark(tweet_id)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.bookmark_tweet(credentials, input_data.tweet_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

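# Illustrative sketch (hypothetical helper): bookmark several tweets in one
# pass, returning only the IDs that were bookmarked successfully. Assumes
# `creds` is a valid TwitterCredentials instance with the bookmark.write scope.
def _example_bookmark_many(
    creds: TwitterCredentials, tweet_ids: list[str]
) -> list[str]:
    done = []
    for tweet_id in tweet_ids:
        try:
            if TwitterBookmarkTweetBlock.bookmark_tweet(creds, tweet_id):
                done.append(tweet_id)
        except tweepy.TweepyException:
            continue  # skip tweets that cannot be bookmarked
    return done
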
class TwitterGetBookmarkedTweetsBlock(Block):
    """
    Gets all your bookmarked tweets from Twitter
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "bookmark.read", "users.read", "offline.access"]
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        # Commonly used outputs
        id: list[str] = SchemaField(description="All Tweet IDs")
        text: list[str] = SchemaField(description="All Tweet texts")
        userId: list[str] = SchemaField(description="IDs of the tweet authors")
        userName: list[str] = SchemaField(description="Usernames of the tweet authors")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )
        next_token: str = SchemaField(description="Next token for pagination")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="ed26783e-a62f-11ef-9a21-c77c57dd8a1f",
            description="This block retrieves bookmarked tweets from Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetBookmarkedTweetsBlock.Input,
            output_schema=TwitterGetBookmarkedTweetsBlock.Output,
            test_input={
                "max_results": 2,
                "pagination_token": None,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", ["1234567890"]),
                ("text", ["Test tweet"]),
                ("userId", ["12345"]),
                ("userName", ["testuser"]),
                ("data", [{"id": "1234567890", "text": "Test tweet"}]),
            ],
            test_mock={
                "get_bookmarked_tweets": lambda *args, **kwargs: (
                    ["1234567890"],
                    ["Test tweet"],
                    ["12345"],
                    ["testuser"],
                    [{"id": "1234567890", "text": "Test tweet"}],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_bookmarked_tweets(
        credentials: TwitterCredentials,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
            }

            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(
                Response,
                client.get_bookmarks(**params),
            )

            meta = {}
            tweet_ids = []
            tweet_texts = []
            user_ids = []
            user_names = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

                if "users" in included:
                    for user in included["users"]:
                        user_ids.append(str(user["id"]))
                        user_names.append(user["username"])

                return (
                    tweet_ids,
                    tweet_texts,
                    user_ids,
                    user_names,
                    data,
                    included,
                    meta,
                    next_token,
                )

            raise Exception("No bookmarked tweets found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, user_ids, user_names, data, included, meta, next_token = (
                self.get_bookmarked_tweets(
                    credentials,
                    input_data.max_results,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.media_fields,
                    input_data.place_fields,
                    input_data.poll_fields,
                    input_data.tweet_fields,
                    input_data.user_fields,
                )
            )
            if ids:
                yield "id", ids
            if texts:
                yield "text", texts
            if user_ids:
                yield "userId", user_ids
            if user_names:
                yield "userName", user_names
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta
            if next_token:
                yield "next_token", next_token
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

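# Illustrative sketch (hypothetical helper): walk the bookmark timeline with
# next_token until no further page is returned. Assumes `creds` is a valid
# TwitterCredentials instance, and relies on the behavior above where the
# fetch raises once no (more) bookmarks are found.
def _example_all_bookmark_ids(creds: TwitterCredentials) -> list[str]:
    all_ids: list[str] = []
    token: str | None = None
    while True:
        try:
            ids, _, _, _, _, _, _, token = (
                TwitterGetBookmarkedTweetsBlock.get_bookmarked_tweets(
                    creds, 100, token, None, None, None, None, None, None
                )
            )
        except Exception:
            break  # raised when no (more) bookmarked tweets are found
        all_ids.extend(ids)
        if not token:
            break
    return all_ids
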
class TwitterRemoveBookmarkTweetBlock(Block):
    """
    Remove a bookmark for a tweet on Twitter
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "bookmark.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to remove bookmark from",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the bookmark was successfully removed"
        )
        error: str = SchemaField(
            description="Error message if the bookmark removal failed"
        )

    def __init__(self):
        super().__init__(
            id="e4100684-a62f-11ef-9be9-770cb41a2616",
            description="This block removes a bookmark from a tweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterRemoveBookmarkTweetBlock.Input,
            output_schema=TwitterRemoveBookmarkTweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"remove_bookmark_tweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def remove_bookmark_tweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.remove_bookmark(tweet_id)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.remove_bookmark_tweet(credentials, input_data.tweet_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

@@ -1,154 +0,0 @@
import tweepy

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterHideReplyBlock(Block):
    """
    Hides a reply to one of your tweets
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet reply to hide",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="07d58b3e-a630-11ef-a030-93701d1a465e",
            description="This block hides a reply to a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterHideReplyBlock.Input,
            output_schema=TwitterHideReplyBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"hide_reply": lambda *args, **kwargs: True},
        )

    @staticmethod
    def hide_reply(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.hide_reply(id=tweet_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.hide_reply(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterUnhideReplyBlock(Block):
    """
    Unhides a reply to a tweet
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.moderate.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet reply to unhide",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="fcf9e4e4-a62f-11ef-9d85-57d3d06b616a",
            description="This block unhides a reply to a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnhideReplyBlock.Input,
            output_schema=TwitterUnhideReplyBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unhide_reply": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unhide_reply(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unhide_reply(id=tweet_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unhide_reply(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

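# Illustrative sketch (hypothetical helper): temporarily hide a reply, for
# example while it is under moderation review, then restore it. Assumes
# `creds` is a valid TwitterCredentials instance with tweet.moderate.write.
def _example_hide_then_unhide(creds: TwitterCredentials, reply_id: str) -> bool:
    hidden = TwitterHideReplyBlock.hide_reply(creds, reply_id)
    restored = TwitterUnhideReplyBlock.unhide_reply(creds, reply_id)
    return hidden and restored
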
@@ -1,576 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
    TweetExpansionsBuilder,
    UserExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterLikeTweetBlock(Block):
    """
    Likes a tweet
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "like.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to like",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="4d0b4c5c-a630-11ef-8e08-1b14c507b347",
            description="This block likes a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterLikeTweetBlock.Input,
            output_schema=TwitterLikeTweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"like_tweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def like_tweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.like(tweet_id=tweet_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.like_tweet(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

class TwitterGetLikingUsersBlock(Block):
    """
    Gets information about users who liked one of your tweets
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "like.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to get liking users for",
            placeholder="Enter tweet ID",
        )
        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )
        pagination_token: str | None = SchemaField(
            description="Token for getting next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        # Commonly used outputs
        id: list[str] = SchemaField(description="All User IDs who liked the tweet")
        username: list[str] = SchemaField(
            description="All User usernames who liked the tweet"
        )
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="34275000-a630-11ef-b01e-5f00d9077c08",
            description="This block gets information about users who liked a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetLikingUsersBlock.Input,
            output_schema=TwitterGetLikingUsersBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "max_results": 1,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", ["1234567890"]),
                ("username", ["testuser"]),
                ("data", [{"id": "1234567890", "username": "testuser"}]),
            ],
            test_mock={
                "get_liking_users": lambda *args, **kwargs: (
                    ["1234567890"],
                    ["testuser"],
                    [{"id": "1234567890", "username": "testuser"}],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_liking_users(
        credentials: TwitterCredentials,
        tweet_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": tweet_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_liking_users(**params))

            if not response.data and not response.meta:
                raise Exception("No liking users found")

            meta = {}
            user_ids = []
            usernames = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                user_ids = [str(user.id) for user in response.data]
                usernames = [user.username for user in response.data]

                return user_ids, usernames, data, included, meta, next_token

            raise Exception("No liking users found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, usernames, data, included, meta, next_token = self.get_liking_users(
                credentials,
                input_data.tweet_id,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "id", ids
            if usernames:
                yield "username", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

class TwitterGetLikedTweetsBlock(Block):
    """
    Gets information about tweets liked by a user
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "like.read", "offline.access"]
        )

        user_id: str = SchemaField(
            description="ID of the user to get liked tweets for",
            placeholder="Enter user ID",
        )
        max_results: int | None = SchemaField(
            description="Maximum number of results to return (5-100)",
            placeholder="100",
            default=10,
            advanced=True,
        )
        pagination_token: str | None = SchemaField(
            description="Token for getting next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list[str] = SchemaField(description="All Tweet IDs")
        texts: list[str] = SchemaField(description="All Tweet texts")
        userIds: list[str] = SchemaField(
            description="List of user ids that authored the tweets"
        )
        userNames: list[str] = SchemaField(
            description="List of user names that authored the tweets"
        )
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="292e7c78-a630-11ef-9f40-df5dffaca106",
            description="This block gets information about tweets liked by a user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetLikedTweetsBlock.Input,
            output_schema=TwitterGetLikedTweetsBlock.Output,
            test_input={
                "user_id": "1234567890",
                "max_results": 2,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["12345", "67890"]),
                ("texts", ["Tweet 1", "Tweet 2"]),
                ("userIds", ["67890", "67891"]),
                ("userNames", ["testuser1", "testuser2"]),
                (
                    "data",
                    [
                        {"id": "12345", "text": "Tweet 1"},
                        {"id": "67890", "text": "Tweet 2"},
                    ],
                ),
            ],
            test_mock={
                "get_liked_tweets": lambda *args, **kwargs: (
                    ["12345", "67890"],
                    ["Tweet 1", "Tweet 2"],
                    ["67890", "67891"],
                    ["testuser1", "testuser2"],
                    [
                        {"id": "12345", "text": "Tweet 1"},
                        {"id": "67890", "text": "Tweet 2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_liked_tweets(
        credentials: TwitterCredentials,
        user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_liked_tweets(**params))

            if not response.data and not response.meta:
                raise Exception("No liked tweets found")

            meta = {}
            tweet_ids = []
            tweet_texts = []
            user_ids = []
            user_names = []
|
||||
next_token = None
|
||||
|
||||
if response.meta:
|
||||
meta = response.meta
|
||||
next_token = meta.get("next_token")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in response.includes:
|
||||
user_ids = [str(user["id"]) for user in response.includes["users"]]
|
||||
user_names = [
|
||||
user["username"] for user in response.includes["users"]
|
||||
]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
raise Exception("No liked tweets found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_liked_tweets(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
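# A minimal sketch (an illustration, not part of the diff) of how a caller
# might consume a block's run() generator: each yielded item is an
# (output_name, value) pair, so results can be collected into a dict.
# The `block` and `inputs` names here are hypothetical.
#
#   outputs = {}
#   for name, value in block.run(inputs, credentials=credentials):
#       outputs[name] = value
#   liked_ids = outputs.get("ids", [])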


class TwitterUnlikeTweetBlock(Block):
    """
    Unlikes a tweet that was previously liked
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "like.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to unlike",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the operation was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="1ed5eab8-a630-11ef-8e21-cbbbc80cbb85",
            description="This block unlikes a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnlikeTweetBlock.Input,
            output_schema=TwitterUnlikeTweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unlike_tweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unlike_tweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unlike(tweet_id=tweet_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unlike_tweet(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)
@@ -1,545 +0,0 @@
from datetime import datetime
from typing import List, Literal, Optional, Union, cast

import tweepy
from pydantic import BaseModel
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
    TweetDurationBuilder,
    TweetExpansionsBuilder,
    TweetPostBuilder,
    TweetSearchBuilder,
)
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetReplySettingsFilter,
    TweetTimeWindowInputs,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class Media(BaseModel):
    discriminator: Literal["media"]
    media_ids: Optional[List[str]] = None
    media_tagged_user_ids: Optional[List[str]] = None


class DeepLink(BaseModel):
    discriminator: Literal["deep_link"]
    direct_message_deep_link: Optional[str] = None


class Poll(BaseModel):
    discriminator: Literal["poll"]
    poll_options: Optional[List[str]] = None
    poll_duration_minutes: Optional[int] = None


class Place(BaseModel):
    discriminator: Literal["place"]
    place_id: Optional[str] = None


class Quote(BaseModel):
    discriminator: Literal["quote"]
    quote_tweet_id: Optional[str] = None
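# A minimal sketch (illustration only; the payload values are made up) of the
# discriminated union these models form: the "discriminator" field selects
# which concrete attachment model a payload is parsed as. A poll attachment
# would look like
#
#   {"discriminator": "poll", "poll_options": ["Yes", "No"], "poll_duration_minutes": 60}
#
# and a quote attachment like
#
#   {"discriminator": "quote", "quote_tweet_id": "1234567890"}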


class TwitterPostTweetBlock(Block):
    """
    Creates a tweet on Twitter, optionally including one additional element
    such as media, a poll, a place, a quote, or a deep link.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.write", "users.read", "offline.access"]
        )

        tweet_text: str | None = SchemaField(
            description="Text of the tweet to post",
            placeholder="Enter your tweet",
            default=None,
            advanced=False,
        )

        for_super_followers_only: bool = SchemaField(
            description="Tweet exclusively for Super Followers",
            placeholder="Enter for super followers only",
            advanced=True,
            default=False,
        )

        attachment: Union[Media, DeepLink, Poll, Place, Quote] | None = SchemaField(
            discriminator="discriminator",
            description="Additional tweet data (media, deep link, poll, place or quote)",
            advanced=True,
        )

        exclude_reply_user_ids: Optional[List[str]] = SchemaField(
            description="User IDs to exclude from reply Tweet thread. [ex - 6253282]",
            placeholder="Enter user IDs to exclude",
            advanced=True,
            default=None,
        )

        in_reply_to_tweet_id: Optional[str] = SchemaField(
            description="Tweet ID being replied to. Note that in_reply_to_tweet_id must be present in the request if exclude_reply_user_ids is set",
            default=None,
            placeholder="Enter in reply to tweet ID",
            advanced=True,
        )

        reply_settings: TweetReplySettingsFilter = SchemaField(
            description="Who can reply to the Tweet (mentionedUsers or following)",
            placeholder="Enter reply settings",
            advanced=True,
            default=TweetReplySettingsFilter(All_Users=True),
        )

    class Output(BlockSchema):
        tweet_id: str = SchemaField(description="ID of the created tweet")
        tweet_url: str = SchemaField(description="URL to the tweet")
        error: str = SchemaField(
            description="Error message if the tweet posting failed"
        )

    def __init__(self):
        super().__init__(
            id="7bb0048a-a630-11ef-aeb8-abc0dadb9b12",
            description="This block posts a tweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterPostTweetBlock.Input,
            output_schema=TwitterPostTweetBlock.Output,
            test_input={
                "tweet_text": "This is a test tweet.",
                "credentials": TEST_CREDENTIALS_INPUT,
                "attachment": {
                    "discriminator": "deep_link",
                    "direct_message_deep_link": "https://twitter.com/messages/compose",
                },
                "for_super_followers_only": False,
                "exclude_reply_user_ids": [],
                "in_reply_to_tweet_id": "",
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("tweet_id", "1234567890"),
                ("tweet_url", "https://twitter.com/user/status/1234567890"),
            ],
            test_mock={
                "post_tweet": lambda *args, **kwargs: (
                    "1234567890",
                    "https://twitter.com/user/status/1234567890",
                )
            },
        )

    def post_tweet(
        self,
        credentials: TwitterCredentials,
        input_txt: str | None,
        attachment: Union[Media, DeepLink, Poll, Place, Quote] | None,
        for_super_followers_only: bool,
        exclude_reply_user_ids: Optional[List[str]],
        in_reply_to_tweet_id: Optional[str],
        reply_settings: TweetReplySettingsFilter,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = (
                TweetPostBuilder()
                .add_text(input_txt)
                .add_super_followers(for_super_followers_only)
                .add_reply_settings(
                    exclude_reply_user_ids or [],
                    in_reply_to_tweet_id or "",
                    reply_settings,
                )
            )

            if isinstance(attachment, Media):
                params.add_media(
                    attachment.media_ids or [], attachment.media_tagged_user_ids or []
                )
            elif isinstance(attachment, DeepLink):
                params.add_deep_link(attachment.direct_message_deep_link or "")
            elif isinstance(attachment, Poll):
                params.add_poll_options(attachment.poll_options or [])
                params.add_poll_duration(attachment.poll_duration_minutes or 0)
            elif isinstance(attachment, Place):
                params.add_place(attachment.place_id or "")
            elif isinstance(attachment, Quote):
                params.add_quote(attachment.quote_tweet_id or "")

            tweet = cast(Response, client.create_tweet(**params.build()))

            if not tweet.data:
                raise Exception("Failed to create tweet")

            tweet_id = tweet.data["id"]
            tweet_url = f"https://twitter.com/user/status/{tweet_id}"
            return str(tweet_id), tweet_url

        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            tweet_id, tweet_url = self.post_tweet(
                credentials,
                input_data.tweet_text,
                input_data.attachment,
                input_data.for_super_followers_only,
                input_data.exclude_reply_user_ids,
                input_data.in_reply_to_tweet_id,
                input_data.reply_settings,
            )
            yield "tweet_id", tweet_id
            yield "tweet_url", tweet_url

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterDeleteTweetBlock(Block):
    """
    Deletes a tweet on Twitter using the tweet ID
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to delete",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the tweet was successfully deleted"
        )
        error: str = SchemaField(
            description="Error message if the tweet deletion failed"
        )

    def __init__(self):
        super().__init__(
            id="761babf0-a630-11ef-a03d-abceb082f58f",
            description="This block deletes a tweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterDeleteTweetBlock.Input,
            output_schema=TwitterDeleteTweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"delete_tweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def delete_tweet(credentials: TwitterCredentials, tweet_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )
            client.delete_tweet(id=tweet_id, user_auth=False)
            return True
        except tweepy.TweepyException:
            raise
        except Exception:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.delete_tweet(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterSearchRecentTweetsBlock(Block):
    """
    Searches public Tweets posted over the last seven days
    """

    class Input(TweetExpansionInputs, TweetTimeWindowInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        query: str = SchemaField(
            description="Search query (up to 1024 characters)",
            placeholder="Enter search query",
        )

        max_results: int = SchemaField(
            description="Maximum number of results per page (10-500)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination: str | None = SchemaField(
            description="Token for pagination",
            default="",
            placeholder="Enter pagination token",
            advanced=True,
        )

    class Output(BlockSchema):
        # Commonly used outputs
        tweet_ids: list[str] = SchemaField(description="All Tweet IDs")
        tweet_texts: list[str] = SchemaField(description="All Tweet texts")
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="53e5cf8e-a630-11ef-ba85-df6d666fa5d5",
            description="This block searches public Tweets posted over the last seven days.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterSearchRecentTweetsBlock.Input,
            output_schema=TwitterSearchRecentTweetsBlock.Output,
            test_input={
                "query": "from:twitterapi #twitterapi",
                "credentials": TEST_CREDENTIALS_INPUT,
                "max_results": 2,
                "start_time": "2024-12-14T18:30:00.000Z",
                "end_time": "2024-12-17T18:30:00.000Z",
                "since_id": None,
                "until_id": None,
                "sort_order": None,
                "pagination": None,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("tweet_ids", ["1373001119480344583", "1372627771717869568"]),
                (
                    "tweet_texts",
                    [
                        "Looking to get started with the Twitter API but new to APIs in general?",
                        "Thanks to everyone who joined and made today a great session!",
                    ],
                ),
                (
                    "data",
                    [
                        {
                            "id": "1373001119480344583",
                            "text": "Looking to get started with the Twitter API but new to APIs in general?",
                        },
                        {
                            "id": "1372627771717869568",
                            "text": "Thanks to everyone who joined and made today a great session!",
                        },
                    ],
                ),
            ],
            test_mock={
                "search_tweets": lambda *args, **kwargs: (
                    ["1373001119480344583", "1372627771717869568"],
                    [
                        "Looking to get started with the Twitter API but new to APIs in general?",
                        "Thanks to everyone who joined and made today a great session!",
                    ],
                    [
                        {
                            "id": "1373001119480344583",
                            "text": "Looking to get started with the Twitter API but new to APIs in general?",
                        },
                        {
                            "id": "1372627771717869568",
                            "text": "Thanks to everyone who joined and made today a great session!",
                        },
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def search_tweets(
        credentials: TwitterCredentials,
        query: str,
        max_results: int,
        start_time: datetime | None,
        end_time: datetime | None,
        since_id: str | None,
        until_id: str | None,
        sort_order: str | None,
        pagination: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            # Building common params
            params = (
                TweetSearchBuilder()
                .add_query(query)
                .add_pagination(max_results, pagination)
                .build()
            )

            # Adding expansions to params if requested by the user
            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            # Adding the time window to params if requested by the user
            params = (
                TweetDurationBuilder(params)
                .add_start_time(start_time)
                .add_end_time(end_time)
                .add_since_id(since_id)
                .add_until_id(until_id)
                .add_sort_order(sort_order)
                .build()
            )

            response = cast(Response, client.search_recent_tweets(**params))

            if not response.data and not response.meta:
                raise Exception("No tweets found")

            meta = {}
            tweet_ids = []
            tweet_texts = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

                return tweet_ids, tweet_texts, data, included, meta, next_token

            raise Exception("No tweets found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, data, included, meta, next_token = self.search_tweets(
                credentials,
                input_data.query,
                input_data.max_results,
                input_data.start_time,
                input_data.end_time,
                input_data.since_id,
                input_data.until_id,
                input_data.sort_order,
                input_data.pagination,
                input_data.expansions,
                input_data.media_fields,
                input_data.place_fields,
                input_data.poll_fields,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "tweet_ids", ids
            if texts:
                yield "tweet_texts", texts
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
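# A minimal sketch (illustrative values only) of the query and time-window
# inputs this search block expects. The query string uses standard Twitter
# API v2 search operators, and the timestamps are ISO-8601 strings:
#
#   {
#       "query": "from:twitterapi #twitterapi -is:retweet",
#       "max_results": 10,
#       "start_time": "2024-12-14T18:30:00.000Z",
#       "end_time": "2024-12-17T18:30:00.000Z",
#   }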
@@ -1,222 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExcludesFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetQuoteTweetsBlock(Block):
    """
    Gets quote tweets for a specified tweet ID
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to get quotes for",
            placeholder="Enter tweet ID",
        )

        max_results: int | None = SchemaField(
            description="Number of results to return (max 100)",
            default=10,
            advanced=True,
        )

        exclude: TweetExcludesFilter | None = SchemaField(
            description="Types of tweets to exclude", advanced=True, default=None
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination",
            advanced=True,
            default="",
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list = SchemaField(description="All Tweet IDs")
        texts: list = SchemaField(description="All Tweet texts")
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="9fbdd208-a630-11ef-9b97-ab7a3a695ca3",
            description="This block gets quote tweets for a specific tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetQuoteTweetsBlock.Input,
            output_schema=TwitterGetQuoteTweetsBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "max_results": 2,
                "pagination_token": None,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["12345", "67890"]),
                ("texts", ["Tweet 1", "Tweet 2"]),
                (
                    "data",
                    [
                        {"id": "12345", "text": "Tweet 1"},
                        {"id": "67890", "text": "Tweet 2"},
                    ],
                ),
            ],
            test_mock={
                "get_quote_tweets": lambda *args, **kwargs: (
                    ["12345", "67890"],
                    ["Tweet 1", "Tweet 2"],
                    [
                        {"id": "12345", "text": "Tweet 1"},
                        {"id": "67890", "text": "Tweet 2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_quote_tweets(
        credentials: TwitterCredentials,
        tweet_id: str,
        max_results: int | None,
        exclude: TweetExcludesFilter | None,
        pagination_token: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": tweet_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "exclude": None if exclude == TweetExcludesFilter() else exclude,
                "user_auth": False,
            }

            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_quote_tweets(**params))

            meta = {}
            tweet_ids = []
            tweet_texts = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

                return tweet_ids, tweet_texts, data, included, meta, next_token

            raise Exception("No quote tweets found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, data, included, meta, next_token = self.get_quote_tweets(
                credentials,
                input_data.tweet_id,
                input_data.max_results,
                input_data.exclude,
                input_data.pagination_token,
                input_data.expansions,
                input_data.media_fields,
                input_data.place_fields,
                input_data.poll_fields,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if texts:
                yield "texts", texts
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
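# A minimal sketch (hypothetical wrapper, not part of the diff) of how the
# next_token output can drive pagination: keep re-invoking the fetch with the
# previous page's token until the API stops returning one. Note the method
# raises when a page comes back with no data, so a real loop would also
# handle that case. The `process` consumer is hypothetical.
#
#   token = ""
#   while True:
#       ids, texts, data, included, meta, token = (
#           TwitterGetQuoteTweetsBlock.get_quote_tweets(
#               credentials, "1234567890", 100, None, token,
#               None, None, None, None, None, None,
#           )
#       )
#       process(data)
#       if not token:
#           break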
@@ -1,363 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterRetweetBlock(Block):
    """
    Retweets a tweet on Twitter
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to retweet",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the retweet was successful")
        error: str = SchemaField(description="Error message if the retweet failed")

    def __init__(self):
        super().__init__(
            id="bd7b8d3a-a630-11ef-be96-6f4aa4c3c4f4",
            description="This block retweets a tweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterRetweetBlock.Input,
            output_schema=TwitterRetweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"retweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def retweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.retweet(
                tweet_id=tweet_id,
                user_auth=False,
            )

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.retweet(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterRemoveRetweetBlock(Block):
    """
    Removes a retweet on Twitter
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "tweet.write", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet whose retweet to remove",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the retweet was successfully removed"
        )
        error: str = SchemaField(description="Error message if the removal failed")

    def __init__(self):
        super().__init__(
            id="b6e663f0-a630-11ef-a7f0-8b9b0c542ff8",
            description="This block removes a retweet on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterRemoveRetweetBlock.Input,
            output_schema=TwitterRemoveRetweetBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"remove_retweet": lambda *args, **kwargs: True},
        )

    @staticmethod
    def remove_retweet(
        credentials: TwitterCredentials,
        tweet_id: str,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unretweet(
                source_tweet_id=tweet_id,
                user_auth=False,
            )

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.remove_retweet(
                credentials,
                input_data.tweet_id,
            )
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetRetweetersBlock(Block):
    """
    Gets information about who has retweeted a tweet
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="ID of the tweet to get retweeters for",
            placeholder="Enter tweet ID",
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results per page (1-100)",
            default=10,
            placeholder="Enter max results",
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination",
            placeholder="Enter pagination token",
            default="",
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list = SchemaField(description="List of user ids who retweeted")
        names: list = SchemaField(description="List of user names who retweeted")
        usernames: list = SchemaField(
            description="List of user usernames who retweeted"
        )
        next_token: str = SchemaField(description="Token for next page of results")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="ad7aa6fa-a630-11ef-a6b0-e7ca640aa030",
            description="This block gets information about who has retweeted a tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetRetweetersBlock.Input,
            output_schema=TwitterGetRetweetersBlock.Output,
            test_input={
                "tweet_id": "1234567890",
                "credentials": TEST_CREDENTIALS_INPUT,
                "max_results": 1,
                "pagination_token": "",
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["12345"]),
                ("names", ["Test User"]),
                ("usernames", ["testuser"]),
                (
                    "data",
                    [{"id": "12345", "name": "Test User", "username": "testuser"}],
                ),
            ],
            test_mock={
                "get_retweeters": lambda *args, **kwargs: (
                    [{"id": "12345", "name": "Test User", "username": "testuser"}],
                    {},
                    {},
                    ["12345"],
                    ["Test User"],
                    ["testuser"],
                    None,
                )
            },
        )

    @staticmethod
    def get_retweeters(
        credentials: TwitterCredentials,
        tweet_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": tweet_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_retweeters(**params))

            meta = {}
            ids = []
            names = []
            usernames = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                ids = [str(user.id) for user in response.data]
                names = [user.name for user in response.data]
                usernames = [user.username for user in response.data]
                return data, included, meta, ids, names, usernames, next_token

            raise Exception("No retweeters found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            data, included, meta, ids, names, usernames, next_token = (
                self.get_retweeters(
                    credentials,
                    input_data.tweet_id,
                    input_data.max_results,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.tweet_fields,
                    input_data.user_fields,
                )
            )

            if ids:
                yield "ids", ids
            if names:
                yield "names", names
            if usernames:
                yield "usernames", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
@@ -1,757 +0,0 @@
from datetime import datetime
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import (
    TweetDurationBuilder,
    TweetExpansionsBuilder,
)
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetTimeWindowInputs,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetUserMentionsBlock(Block):
    """
    Returns Tweets mentioning a single user, specified by user ID
    """

    class Input(TweetExpansionInputs, TweetTimeWindowInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        user_id: str = SchemaField(
            description="Unique identifier of the user for whom to return Tweets mentioning the user",
            placeholder="Enter user ID",
        )

        max_results: int | None = SchemaField(
            description="Number of tweets to retrieve (5-100)",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination", default="", advanced=True
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list[str] = SchemaField(description="List of Tweet IDs")
        texts: list[str] = SchemaField(description="All Tweet texts")

        userIds: list[str] = SchemaField(
            description="List of user ids that mentioned the user"
        )
        userNames: list[str] = SchemaField(
            description="List of user names that mentioned the user"
        )
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="e01c890c-a630-11ef-9e20-37da24888bd0",
            description="This block retrieves Tweets mentioning a specific user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetUserMentionsBlock.Input,
            output_schema=TwitterGetUserMentionsBlock.Output,
            test_input={
                "user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
                "max_results": 2,
                "start_time": "2024-12-14T18:30:00.000Z",
                "end_time": "2024-12-17T18:30:00.000Z",
                "since_id": "",
                "until_id": "",
                "sort_order": None,
                "pagination_token": None,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1373001119480344583", "1372627771717869568"]),
                ("texts", ["Test mention 1", "Test mention 2"]),
                ("userIds", ["67890", "67891"]),
                ("userNames", ["testuser1", "testuser2"]),
                (
                    "data",
                    [
                        {"id": "1373001119480344583", "text": "Test mention 1"},
                        {"id": "1372627771717869568", "text": "Test mention 2"},
                    ],
                ),
            ],
            test_mock={
                "get_mentions": lambda *args, **kwargs: (
                    ["1373001119480344583", "1372627771717869568"],
                    ["Test mention 1", "Test mention 2"],
                    ["67890", "67891"],
                    ["testuser1", "testuser2"],
                    [
                        {"id": "1373001119480344583", "text": "Test mention 1"},
                        {"id": "1372627771717869568", "text": "Test mention 2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_mentions(
        credentials: TwitterCredentials,
        user_id: str,
        max_results: int | None,
        start_time: datetime | None,
        end_time: datetime | None,
        since_id: str | None,
        until_id: str | None,
        sort_order: str | None,
        pagination: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": user_id,
                "max_results": max_results,
                "pagination_token": None if pagination == "" else pagination,
                "user_auth": False,
            }

            # Adding expansions to params if requested by the user
            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            # Adding the time window to params if requested by the user
            params = (
                TweetDurationBuilder(params)
                .add_start_time(start_time)
                .add_end_time(end_time)
                .add_since_id(since_id)
                .add_until_id(until_id)
                .add_sort_order(sort_order)
                .build()
            )

            response = cast(
                Response,
                client.get_users_mentions(**params),
            )

            if not response.data and not response.meta:
                raise Exception("No tweets found")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)
            meta = response.meta or {}
            next_token = meta.get("next_token", "")

            tweet_ids = []
            tweet_texts = []
            user_ids = []
            user_names = []

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

            if "users" in included:
                user_ids = [str(user["id"]) for user in included["users"]]
                user_names = [user["username"] for user in included["users"]]

            return (
                tweet_ids,
                tweet_texts,
                user_ids,
                user_names,
                data,
                included,
                meta,
                next_token,
            )

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, user_ids, user_names, data, included, meta, next_token = (
                self.get_mentions(
                    credentials,
                    input_data.user_id,
                    input_data.max_results,
                    input_data.start_time,
                    input_data.end_time,
                    input_data.since_id,
                    input_data.until_id,
                    input_data.sort_order,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.media_fields,
                    input_data.place_fields,
                    input_data.poll_fields,
                    input_data.tweet_fields,
                    input_data.user_fields,
                )
            )
            if ids:
                yield "ids", ids
            if texts:
                yield "texts", texts
            if user_ids:
                yield "userIds", user_ids
            if user_names:
                yield "userNames", user_names
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
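# A minimal sketch (illustrative, field values made up) of why the author
# usernames come from `included` rather than from the tweets themselves:
# requesting an author expansion makes the API return the referenced user
# objects in the response's includes section, e.g.
#
#   included = {
#       "users": [
#           {"id": "67890", "username": "testuser1"},
#           {"id": "67891", "username": "testuser2"},
#       ]
#   }
#
# which the code above flattens into the userIds / userNames outputs.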


class TwitterGetHomeTimelineBlock(Block):
    """
    Returns a collection of the most recent Tweets and Retweets posted by you and users you follow
    """

    class Input(TweetExpansionInputs, TweetTimeWindowInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        max_results: int | None = SchemaField(
            description="Number of tweets to retrieve (5-100)",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for pagination", default="", advanced=True
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list[str] = SchemaField(description="List of Tweet IDs")
        texts: list[str] = SchemaField(description="All Tweet texts")

        userIds: list[str] = SchemaField(
            description="List of user ids that authored the tweets"
        )
        userNames: list[str] = SchemaField(
            description="List of user names that authored the tweets"
        )
        next_token: str = SchemaField(description="Next token for pagination")

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data requested via the expansions field (optional)"
        )
        meta: dict = SchemaField(
            description="Provides metadata such as pagination info (next_token) or result counts"
        )

        # error
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="d222a070-a630-11ef-a18a-3f52f76c6962",
            description="This block retrieves the authenticated user's home timeline.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetHomeTimelineBlock.Input,
            output_schema=TwitterGetHomeTimelineBlock.Output,
            test_input={
                "credentials": TEST_CREDENTIALS_INPUT,
                "max_results": 2,
                "start_time": "2024-12-14T18:30:00.000Z",
                "end_time": "2024-12-17T18:30:00.000Z",
                "since_id": None,
                "until_id": None,
                "sort_order": None,
                "pagination_token": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1373001119480344583", "1372627771717869568"]),
                ("texts", ["Test tweet 1", "Test tweet 2"]),
                ("userIds", ["67890", "67891"]),
                ("userNames", ["testuser1", "testuser2"]),
                (
                    "data",
                    [
                        {"id": "1373001119480344583", "text": "Test tweet 1"},
                        {"id": "1372627771717869568", "text": "Test tweet 2"},
                    ],
                ),
            ],
            test_mock={
                "get_timeline": lambda *args, **kwargs: (
                    ["1373001119480344583", "1372627771717869568"],
                    ["Test tweet 1", "Test tweet 2"],
                    ["67890", "67891"],
                    ["testuser1", "testuser2"],
                    [
                        {"id": "1373001119480344583", "text": "Test tweet 1"},
                        {"id": "1372627771717869568", "text": "Test tweet 2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_timeline(
        credentials: TwitterCredentials,
        max_results: int | None,
        start_time: datetime | None,
        end_time: datetime | None,
        since_id: str | None,
        until_id: str | None,
        sort_order: str | None,
        pagination: str | None,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "max_results": max_results,
                "pagination_token": None if pagination == "" else pagination,
                "user_auth": False,
            }

            # Adding expansions to params if requested by the user
            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            # Adding the time window to params if requested by the user
            params = (
                TweetDurationBuilder(params)
                .add_start_time(start_time)
                .add_end_time(end_time)
                .add_since_id(since_id)
                .add_until_id(until_id)
                .add_sort_order(sort_order)
                .build()
            )

            response = cast(
                Response,
                client.get_home_timeline(**params),
            )

            if not response.data and not response.meta:
                raise Exception("No tweets found")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)
            meta = response.meta or {}
            next_token = meta.get("next_token", "")

            tweet_ids = []
            tweet_texts = []
            user_ids = []
            user_names = []

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

            if "users" in included:
                user_ids = [str(user["id"]) for user in included["users"]]
                user_names = [user["username"] for user in included["users"]]

            return (
                tweet_ids,
                tweet_texts,
                user_ids,
                user_names,
                data,
                included,
                meta,
                next_token,
            )

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, user_ids, user_names, data, included, meta, next_token = (
                self.get_timeline(
                    credentials,
                    input_data.max_results,
                    input_data.start_time,
                    input_data.end_time,
                    input_data.since_id,
                    input_data.until_id,
                    input_data.sort_order,
                    input_data.pagination_token,
                    input_data.expansions,
                    input_data.media_fields,
                    input_data.place_fields,
                    input_data.poll_fields,
                    input_data.tweet_fields,
                    input_data.user_fields,
                )
            )
            if ids:
                yield "ids", ids
            if texts:
                yield "texts", texts
            if user_ids:
                yield "userIds", user_ids
            if user_names:
                yield "userNames", user_names
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
|
||||
|
||||
|
||||
class TwitterGetUserTweetsBlock(Block):
|
||||
"""
|
||||
Returns Tweets composed by a single user, specified by the requested user ID
|
||||
"""
|
||||
|
||||
class Input(TweetExpansionInputs, TweetTimeWindowInputs):
|
||||
credentials: TwitterCredentialsInput = TwitterCredentialsField(
|
||||
["tweet.read", "users.read", "offline.access"]
|
||||
)
|
||||
|
||||
user_id: str = SchemaField(
|
||||
description="Unique identifier of the Twitter account (user ID) for whom to return results",
|
||||
placeholder="Enter user ID",
|
||||
)
|
||||
|
||||
max_results: int | None = SchemaField(
|
||||
description="Number of tweets to retrieve (5-100)",
|
||||
default=10,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
pagination_token: str | None = SchemaField(
|
||||
description="Token for pagination", default="", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchema):
|
||||
# Common Outputs that user commonly uses
|
||||
ids: list[str] = SchemaField(description="List of Tweet IDs")
|
||||
texts: list[str] = SchemaField(description="All Tweet texts")
|
||||
|
||||
userIds: list[str] = SchemaField(
|
||||
description="List of user ids that authored the tweets"
|
||||
)
|
||||
userNames: list[str] = SchemaField(
|
||||
description="List of user names that authored the tweets"
|
||||
)
|
||||
next_token: str = SchemaField(description="Next token for pagination")
|
||||
|
||||
# Complete Outputs for advanced use
|
||||
data: list[dict] = SchemaField(description="Complete Tweet data")
|
||||
included: dict = SchemaField(
|
||||
description="Additional data that you have requested (Optional) via Expansions field"
|
||||
)
|
||||
meta: dict = SchemaField(
|
||||
description="Provides metadata such as pagination info (next_token) or result counts"
|
||||
)
|
||||
|
||||
# error
|
||||
error: str = SchemaField(description="Error message if the request failed")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c44c3ef2-a630-11ef-9ff7-eb7b5ea3a5cb",
|
||||
description="This block retrieves Tweets composed by a single user.",
|
||||
categories={BlockCategory.SOCIAL},
|
||||
input_schema=TwitterGetUserTweetsBlock.Input,
|
||||
output_schema=TwitterGetUserTweetsBlock.Output,
|
||||
test_input={
|
||||
"user_id": "12345",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
"max_results": 2,
|
||||
"start_time": "2024-12-14T18:30:00.000Z",
|
||||
"end_time": "2024-12-17T18:30:00.000Z",
|
||||
"since_id": None,
|
||||
"until_id": None,
|
||||
"sort_order": None,
|
||||
"pagination_token": None,
|
||||
"expansions": None,
|
||||
"media_fields": None,
|
||||
"place_fields": None,
|
||||
"poll_fields": None,
|
||||
"tweet_fields": None,
|
||||
"user_fields": None,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[
|
||||
("ids", ["1373001119480344583", "1372627771717869568"]),
|
||||
("texts", ["Test tweet 1", "Test tweet 2"]),
|
||||
("userIds", ["67890", "67891"]),
|
||||
("userNames", ["testuser1", "testuser2"]),
|
||||
(
|
||||
"data",
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
),
|
||||
],
|
||||
test_mock={
|
||||
"get_user_tweets": lambda *args, **kwargs: (
|
||||
["1373001119480344583", "1372627771717869568"],
|
||||
["Test tweet 1", "Test tweet 2"],
|
||||
["67890", "67891"],
|
||||
["testuser1", "testuser2"],
|
||||
[
|
||||
{"id": "1373001119480344583", "text": "Test tweet 1"},
|
||||
{"id": "1372627771717869568", "text": "Test tweet 2"},
|
||||
],
|
||||
{},
|
||||
{},
|
||||
None,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_tweets(
|
||||
credentials: TwitterCredentials,
|
||||
user_id: str,
|
||||
max_results: int | None,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
since_id: str | None,
|
||||
until_id: str | None,
|
||||
sort_order: str | None,
|
||||
pagination: str | None,
|
||||
expansions: ExpansionFilter | None,
|
||||
media_fields: TweetMediaFieldsFilter | None,
|
||||
place_fields: TweetPlaceFieldsFilter | None,
|
||||
poll_fields: TweetPollFieldsFilter | None,
|
||||
tweet_fields: TweetFieldsFilter | None,
|
||||
user_fields: TweetUserFieldsFilter | None,
|
||||
):
|
||||
try:
|
||||
client = tweepy.Client(
|
||||
bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": user_id,
|
||||
"max_results": max_results,
|
||||
"pagination_token": None if pagination == "" else pagination,
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
# Adding expansions to params If required by the user
|
||||
params = (
|
||||
TweetExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_media_fields(media_fields)
|
||||
.add_place_fields(place_fields)
|
||||
.add_poll_fields(poll_fields)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Adding time window to params If required by the user
|
||||
params = (
|
||||
TweetDurationBuilder(params)
|
||||
.add_start_time(start_time)
|
||||
.add_end_time(end_time)
|
||||
.add_since_id(since_id)
|
||||
.add_until_id(until_id)
|
||||
.add_sort_order(sort_order)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(
|
||||
Response,
|
||||
client.get_users_tweets(**params),
|
||||
)
|
||||
|
||||
if not response.data and not response.meta:
|
||||
raise Exception("No tweets found")
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_list(response.data)
|
||||
meta = response.meta or {}
|
||||
next_token = meta.get("next_token", "")
|
||||
|
||||
tweet_ids = []
|
||||
tweet_texts = []
|
||||
user_ids = []
|
||||
user_names = []
|
||||
|
||||
if response.data:
|
||||
tweet_ids = [str(tweet.id) for tweet in response.data]
|
||||
tweet_texts = [tweet.text for tweet in response.data]
|
||||
|
||||
if "users" in included:
|
||||
user_ids = [str(user["id"]) for user in included["users"]]
|
||||
user_names = [user["username"] for user in included["users"]]
|
||||
|
||||
return (
|
||||
tweet_ids,
|
||||
tweet_texts,
|
||||
user_ids,
|
||||
user_names,
|
||||
data,
|
||||
included,
|
||||
meta,
|
||||
next_token,
|
||||
)
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
ids, texts, user_ids, user_names, data, included, meta, next_token = (
|
||||
self.get_user_tweets(
|
||||
credentials,
|
||||
input_data.user_id,
|
||||
input_data.max_results,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.since_id,
|
||||
input_data.until_id,
|
||||
input_data.sort_order,
|
||||
input_data.pagination_token,
|
||||
input_data.expansions,
|
||||
input_data.media_fields,
|
||||
input_data.place_fields,
|
||||
input_data.poll_fields,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
)
|
||||
if ids:
|
||||
yield "ids", ids
|
||||
if texts:
|
||||
yield "texts", texts
|
||||
if user_ids:
|
||||
yield "userIds", user_ids
|
||||
if user_names:
|
||||
yield "userNames", user_names
|
||||
if next_token:
|
||||
yield "next_token", next_token
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
if meta:
|
||||
yield "meta", meta
|
||||
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)
|
||||
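Because these timeline-style blocks both accept a pagination token input and yield a next_token output, pages can be chained by feeding one call's next_token into the next call's pagination token. A minimal sketch of that loop, assuming a valid TwitterCredentials instance and using the get_user_tweets helper above; the fetch_all_user_tweet_ids wrapper itself is hypothetical:

# Hypothetical pagination driver; assumes `credentials` is a real
# TwitterCredentials instance with a valid bearer token.
def fetch_all_user_tweet_ids(credentials, user_id: str, max_pages: int = 5) -> list[str]:
    collected: list[str] = []
    pagination = ""  # "" means "first page" in the params logic above
    for _ in range(max_pages):
        ids, _texts, _uids, _unames, _data, _included, _meta, next_token = (
            TwitterGetUserTweetsBlock.get_user_tweets(
                credentials, user_id, 100,
                None, None, None, None, None,  # start/end/since/until/sort
                pagination,
                None, None, None, None, None, None,  # expansion filters
            )
        )
        collected.extend(ids)
        if not next_token:
            break  # no further pages advertised by the API
        pagination = next_token
    return collected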
@@ -1,361 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import TweetExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    ExpansionFilter,
    TweetExpansionInputs,
    TweetFieldsFilter,
    TweetMediaFieldsFilter,
    TweetPlaceFieldsFilter,
    TweetPollFieldsFilter,
    TweetUserFieldsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterGetTweetBlock(Block):
    """
    Returns information about a single Tweet specified by the requested ID
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        tweet_id: str = SchemaField(
            description="Unique identifier of the Tweet to request (ex: 1460323737035677698)",
            placeholder="Enter tweet ID",
        )

    class Output(BlockSchema):
        # Commonly used outputs
        id: str = SchemaField(description="Tweet ID")
        text: str = SchemaField(description="Tweet text")
        userId: str = SchemaField(description="ID of the tweet author")
        userName: str = SchemaField(description="Username of the tweet author")

        # Complete outputs for advanced use
        data: dict = SchemaField(description="Tweet data")
        included: dict = SchemaField(
            description="Additional data that you have requested (optional) via the expansions field"
        )
        meta: dict = SchemaField(description="Metadata about the tweet")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="f5155c3a-a630-11ef-9cc1-a309988b4d92",
            description="This block retrieves information about a specific Tweet.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetTweetBlock.Input,
            output_schema=TwitterGetTweetBlock.Output,
            test_input={
                "tweet_id": "1460323737035677698",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", "1460323737035677698"),
                ("text", "Test tweet content"),
                ("userId", "12345"),
                ("userName", "testuser"),
                ("data", {"id": "1460323737035677698", "text": "Test tweet content"}),
                ("included", {"users": [{"id": "12345", "username": "testuser"}]}),
                ("meta", {"result_count": 1}),
            ],
            test_mock={
                "get_tweet": lambda *args, **kwargs: (
                    {"id": "1460323737035677698", "text": "Test tweet content"},
                    {"users": [{"id": "12345", "username": "testuser"}]},
                    {"result_count": 1},
                    "12345",
                    "testuser",
                )
            },
        )

    @staticmethod
    def get_tweet(
        credentials: TwitterCredentials,
        tweet_id: str,
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )
            params = {"id": tweet_id, "user_auth": False}

            # Add expansions to params if requested by the user
            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_tweet(**params))

            meta = {}
            user_id = ""
            user_name = ""

            if response.meta:
                meta = response.meta

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_dict(response.data)

            if included and "users" in included:
                user_id = str(included["users"][0]["id"])
                user_name = included["users"][0]["username"]

            if response.data:
                return data, included, meta, user_id, user_name

            raise Exception("Tweet not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            tweet_data, included, meta, user_id, user_name = self.get_tweet(
                credentials,
                input_data.tweet_id,
                input_data.expansions,
                input_data.media_fields,
                input_data.place_fields,
                input_data.poll_fields,
                input_data.tweet_fields,
                input_data.user_fields,
            )

            yield "id", str(tweet_data["id"])
            yield "text", tweet_data["text"]
            if user_id:
                yield "userId", user_id
            if user_name:
                yield "userName", user_name
            yield "data", tweet_data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetTweetsBlock(Block):
    """
    Returns information about multiple Tweets specified by the requested IDs
    """

    class Input(TweetExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["tweet.read", "users.read", "offline.access"]
        )

        tweet_ids: list[str] = SchemaField(
            description="List of Tweet IDs to request (up to 100)",
            placeholder="Enter tweet IDs",
        )

    class Output(BlockSchema):
        # Commonly used outputs
        ids: list[str] = SchemaField(description="All Tweet IDs")
        texts: list[str] = SchemaField(description="All Tweet texts")
        userIds: list[str] = SchemaField(
            description="List of user IDs that authored the tweets"
        )
        userNames: list[str] = SchemaField(
            description="List of user names that authored the tweets"
        )

        # Complete outputs for advanced use
        data: list[dict] = SchemaField(description="Complete Tweet data")
        included: dict = SchemaField(
            description="Additional data that you have requested (optional) via the expansions field"
        )
        meta: dict = SchemaField(description="Metadata about the tweets")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="e7cc5420-a630-11ef-bfaf-13bdd8096a51",
            description="This block retrieves information about multiple Tweets.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetTweetsBlock.Input,
            output_schema=TwitterGetTweetsBlock.Output,
            test_input={
                "tweet_ids": ["1460323737035677698"],
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "media_fields": None,
                "place_fields": None,
                "poll_fields": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1460323737035677698"]),
                ("texts", ["Test tweet content"]),
                ("userIds", ["67890"]),
                ("userNames", ["testuser1"]),
                ("data", [{"id": "1460323737035677698", "text": "Test tweet content"}]),
                ("included", {"users": [{"id": "67890", "username": "testuser1"}]}),
                ("meta", {"result_count": 1}),
            ],
            test_mock={
                "get_tweets": lambda *args, **kwargs: (
                    ["1460323737035677698"],  # ids
                    ["Test tweet content"],  # texts
                    ["67890"],  # user_ids
                    ["testuser1"],  # user_names
                    [
                        {"id": "1460323737035677698", "text": "Test tweet content"}
                    ],  # data
                    {"users": [{"id": "67890", "username": "testuser1"}]},  # included
                    {"result_count": 1},  # meta
                )
            },
        )

    @staticmethod
    def get_tweets(
        credentials: TwitterCredentials,
        tweet_ids: list[str],
        expansions: ExpansionFilter | None,
        media_fields: TweetMediaFieldsFilter | None,
        place_fields: TweetPlaceFieldsFilter | None,
        poll_fields: TweetPollFieldsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )
            params = {"ids": tweet_ids, "user_auth": False}

            # Add expansions to params if requested by the user
            params = (
                TweetExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_media_fields(media_fields)
                .add_place_fields(place_fields)
                .add_poll_fields(poll_fields)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_tweets(**params))

            if not response.data and not response.meta:
                raise Exception("No tweets found")

            tweet_ids = []
            tweet_texts = []
            user_ids = []
            user_names = []
            meta = {}

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                tweet_ids = [str(tweet.id) for tweet in response.data]
                tweet_texts = [tweet.text for tweet in response.data]

            if included and "users" in included:
                for user in included["users"]:
                    user_ids.append(str(user["id"]))
                    user_names.append(user["username"])

            if response.meta:
                meta = response.meta

            return tweet_ids, tweet_texts, user_ids, user_names, data, included, meta

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, texts, user_ids, user_names, data, included, meta = self.get_tweets(
                credentials,
                input_data.tweet_ids,
                input_data.expansions,
                input_data.media_fields,
                input_data.place_fields,
                input_data.poll_fields,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if texts:
                yield "texts", texts
            if user_ids:
                yield "userIds", user_ids
            if user_names:
                yield "userNames", user_names
            if data:
                yield "data", data
            if included:
                yield "included", included
            if meta:
                yield "meta", meta

        except Exception as e:
            yield "error", handle_tweepy_exception(e)
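These lookup blocks lean on TweetExpansionsBuilder from _builders, whose implementation is outside this file; the contract the call sites rely on is that each add_* call merges a field into the params dict only when the caller supplied a non-None filter. A rough sketch of that contract, under the assumption that the real builder behaves this way:

# Illustrative stand-in for the builder contract assumed above; the real
# TweetExpansionsBuilder lives in backend.blocks.twitter._builders and is
# not shown in this file.
class ParamsBuilderSketch:
    def __init__(self, params: dict):
        self.params = dict(params)

    def add(self, key: str, value) -> "ParamsBuilderSketch":
        if value is not None:  # skip filters the caller left unset
            self.params[key] = value
        return self

    def build(self) -> dict:
        return self.params


# Unset filters leave params untouched:
assert ParamsBuilderSketch({"id": "1"}).add("expansions", None).build() == {"id": "1"}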
@@ -1,305 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import IncludesSerializer
from backend.blocks.twitter._types import (
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterUnblockUserBlock(Block):
    """
    Unblock a specific user on Twitter. The request succeeds with no action when the
    user sends a request to a user they're not blocking or have already unblocked.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["block.write", "users.read", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to unblock",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the unblock was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="0f1b6570-a631-11ef-a3ea-230cbe9650dd",
            description="This block unblocks a specific user on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnblockUserBlock.Input,
            output_schema=TwitterUnblockUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unblock_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unblock_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unblock(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unblock_user(credentials, input_data.target_user_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetBlockedUsersBlock(Block):
    """
    Get a list of users who are blocked by the authenticating user
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access", "block.read"]
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-1000, default 100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for retrieving next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        user_ids: list[str] = SchemaField(description="List of blocked user IDs")
        usernames_: list[str] = SchemaField(description="List of blocked usernames")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")
        next_token: str = SchemaField(description="Next token for pagination")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="05f409e8-a631-11ef-ae89-93de863ee30d",
            description="This block retrieves a list of users blocked by the authenticating user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetBlockedUsersBlock.Input,
            output_schema=TwitterGetBlockedUsersBlock.Output,
            test_input={
                "max_results": 10,
                "pagination_token": "",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("user_ids", ["12345", "67890"]),
                ("usernames_", ["testuser1", "testuser2"]),
            ],
            test_mock={
                "get_blocked_users": lambda *args, **kwargs: (
                    {},  # included
                    {},  # meta
                    ["12345", "67890"],  # user_ids
                    ["testuser1", "testuser2"],  # usernames
                    None,  # next_token
                )
            },
        )

    @staticmethod
    def get_blocked_users(
        credentials: TwitterCredentials,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_blocked(**params))

            meta = {}
            user_ids = []
            usernames = []
            next_token = None

            included = IncludesSerializer.serialize(response.includes)

            if response.data:
                for user in response.data:
                    user_ids.append(str(user.id))
                    usernames.append(user.username)

            if response.meta:
                meta = response.meta
                if "next_token" in meta:
                    next_token = meta["next_token"]

            if user_ids and usernames:
                return included, meta, user_ids, usernames, next_token
            else:
                raise tweepy.TweepyException("No blocked users found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            included, meta, user_ids, usernames, next_token = self.get_blocked_users(
                credentials,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if user_ids:
                yield "user_ids", user_ids
            if usernames:
                yield "usernames_", usernames
            if included:
                yield "included", included
            if meta:
                yield "meta", meta
            if next_token:
                yield "next_token", next_token
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterBlockUserBlock(Block):
    """
    Block a specific user on Twitter
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["block.write", "users.read", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to block",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(description="Whether the block was successful")
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="fc258b94-a630-11ef-abc3-df050b75b816",
            description="This block blocks a specific user on Twitter.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterBlockUserBlock.Input,
            output_schema=TwitterBlockUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"block_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def block_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.block(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.block_user(credentials, input_data.target_user_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)
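Since run() is a generator of (output_name, value) pairs, a caller can materialize one execution of any of these blocks as a plain dict. A sketch, assuming the TEST_CREDENTIALS / TEST_CREDENTIALS_INPUT fixtures above can stand in for live credentials:

# Sketch of driving a block directly; assumes the test fixtures are
# accepted wherever real credentials would be.
block = TwitterBlockUserBlock()
outputs = dict(
    block.run(
        TwitterBlockUserBlock.Input(
            credentials=TEST_CREDENTIALS_INPUT,
            target_user_id="12345",
        ),
        credentials=TEST_CREDENTIALS,
    )
)
# The success path yields {"success": True}; failures yield {"error": "..."}.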
@@ -1,510 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterUnfollowUserBlock(Block):
    """
    Allows a user to unfollow another user specified by target user ID.
    The request succeeds with no action when the authenticated user sends a request
    to a user they're not following or have already unfollowed.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "users.write", "follows.write", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to unfollow",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the unfollow action was successful"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="37e386a4-a631-11ef-b7bd-b78204b35fa4",
            description="This block unfollows a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnfollowUserBlock.Input,
            output_schema=TwitterUnfollowUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unfollow_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unfollow_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unfollow_user(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unfollow_user(credentials, input_data.target_user_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterFollowUserBlock(Block):
    """
    Allows a user to follow another user specified by target user ID. If the target
    user does not have public Tweets, this endpoint will send a follow request.
    The request succeeds with no action when the authenticated user sends a request
    to a user they're already following, or when they send a follow request to a
    user that does not have public Tweets.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "users.write", "follows.write", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to follow",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the follow action was successful"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="1aae6a5e-a631-11ef-a090-435900c6d429",
            description="This block follows a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterFollowUserBlock.Input,
            output_schema=TwitterFollowUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[("success", True)],
            test_mock={"follow_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def follow_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.follow_user(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.follow_user(credentials, input_data.target_user_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetFollowersBlock(Block):
    """
    Retrieves a list of followers for a specified Twitter user ID
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access", "follows.read"]
        )

        target_user_id: str = SchemaField(
            description="The user ID whose followers you would like to retrieve",
            placeholder="Enter target user ID",
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-1000, default 100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for retrieving next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        ids: list[str] = SchemaField(description="List of follower user IDs")
        usernames: list[str] = SchemaField(description="List of follower usernames")
        next_token: str = SchemaField(description="Next token for pagination")

        data: list[dict] = SchemaField(description="Complete user data for followers")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="30f66410-a631-11ef-8fe7-d7f888b4f43c",
            description="This block retrieves followers of a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetFollowersBlock.Input,
            output_schema=TwitterGetFollowersBlock.Output,
            test_input={
                "target_user_id": "12345",
                "max_results": 1,
                "pagination_token": "",
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1234567890"]),
                ("usernames", ["testuser"]),
                ("data", [{"id": "1234567890", "username": "testuser"}]),
            ],
            test_mock={
                "get_followers": lambda *args, **kwargs: (
                    ["1234567890"],
                    ["testuser"],
                    [{"id": "1234567890", "username": "testuser"}],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_followers(
        credentials: TwitterCredentials,
        target_user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": target_user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_users_followers(**params))

            meta = {}
            follower_ids = []
            follower_usernames = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                follower_ids = [str(user.id) for user in response.data]
                follower_usernames = [user.username for user in response.data]

            return (
                follower_ids,
                follower_usernames,
                data,
                included,
                meta,
                next_token,
            )

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, usernames, data, includes, meta, next_token = self.get_followers(
                credentials,
                input_data.target_user_id,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if usernames:
                yield "usernames", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if includes:
                yield "includes", includes
            if meta:
                yield "meta", meta
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetFollowingBlock(Block):
    """
    Retrieves a list of users that a specified Twitter user ID is following
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access", "follows.read"]
        )

        target_user_id: str = SchemaField(
            description="The user ID whose following you would like to retrieve",
            placeholder="Enter target user ID",
        )

        max_results: int | None = SchemaField(
            description="Maximum number of results to return (1-1000, default 100)",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token for retrieving next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        ids: list[str] = SchemaField(description="List of following user IDs")
        usernames: list[str] = SchemaField(description="List of following usernames")
        next_token: str = SchemaField(description="Next token for pagination")

        data: list[dict] = SchemaField(description="Complete user data for following")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="264a399c-a631-11ef-a97d-bfde4ca91173",
            description="This block retrieves the users that a specified Twitter user is following.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetFollowingBlock.Input,
            output_schema=TwitterGetFollowingBlock.Output,
            test_input={
                "target_user_id": "12345",
                "max_results": 1,
                "pagination_token": None,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["1234567890"]),
                ("usernames", ["testuser"]),
                ("data", [{"id": "1234567890", "username": "testuser"}]),
            ],
            test_mock={
                "get_following": lambda *args, **kwargs: (
                    ["1234567890"],
                    ["testuser"],
                    [{"id": "1234567890", "username": "testuser"}],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_following(
        credentials: TwitterCredentials,
        target_user_id: str,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "id": target_user_id,
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_users_following(**params))

            meta = {}
            following_ids = []
            following_usernames = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                following_ids = [str(user.id) for user in response.data]
                following_usernames = [user.username for user in response.data]

            return (
                following_ids,
                following_usernames,
                data,
                included,
                meta,
                next_token,
            )

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, usernames, data, includes, meta, next_token = self.get_following(
                credentials,
                input_data.target_user_id,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if usernames:
                yield "usernames", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if includes:
                yield "includes", includes
            if meta:
                yield "meta", meta
        except Exception as e:
            yield "error", handle_tweepy_exception(e)
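For a default Input (max_results 10, empty pagination token, no expansion filters), the params that reach client.get_users_followers reduce to a small dict. Assuming the expansions builder passes params through unchanged when every filter is None, as sketched earlier, that dict is:

# Worked example of the request params for a default follower lookup
# (assumes the builder adds nothing when all filters are None):
params = {
    "id": "12345",
    "max_results": 10,
    "pagination_token": None,  # "" is normalized to None above
    "user_auth": False,
}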
@@ -1,328 +0,0 @@
from typing import cast

import tweepy
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class TwitterUnmuteUserBlock(Block):
    """
    Allows a user to unmute another user specified by target user ID.
    The request succeeds with no action when the user sends a request to a user
    they're not muting or have already unmuted.
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "users.write", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to unmute",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the unmute action was successful"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="40458504-a631-11ef-940b-eff92be55422",
            description="This block unmutes a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterUnmuteUserBlock.Input,
            output_schema=TwitterUnmuteUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"unmute_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def unmute_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.unmute(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.unmute_user(credentials, input_data.target_user_id)
            yield "success", success

        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterGetMutedUsersBlock(Block):
    """
    Returns a list of users who are muted by the authenticating user
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access"]
        )

        max_results: int | None = SchemaField(
            description="The maximum number of results to be returned per page (1-1000). Default is 100.",
            placeholder="Enter max results",
            default=10,
            advanced=True,
        )

        pagination_token: str | None = SchemaField(
            description="Token to request next/previous page of results",
            placeholder="Enter pagination token",
            default="",
            advanced=True,
        )

    class Output(BlockSchema):
        ids: list[str] = SchemaField(description="List of muted user IDs")
        usernames: list[str] = SchemaField(description="List of muted usernames")
        next_token: str = SchemaField(description="Next token for pagination")

        data: list[dict] = SchemaField(description="Complete user data for muted users")
        includes: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        meta: dict = SchemaField(description="Metadata including pagination info")

        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="475024da-a631-11ef-9ccd-f724b8b03cda",
            description="This block gets a list of users muted by the authenticating user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetMutedUsersBlock.Input,
            output_schema=TwitterGetMutedUsersBlock.Output,
            test_input={
                "max_results": 2,
                "pagination_token": "",
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["12345", "67890"]),
                ("usernames", ["testuser1", "testuser2"]),
                (
                    "data",
                    [
                        {"id": "12345", "username": "testuser1"},
                        {"id": "67890", "username": "testuser2"},
                    ],
                ),
            ],
            test_mock={
                "get_muted_users": lambda *args, **kwargs: (
                    ["12345", "67890"],
                    ["testuser1", "testuser2"],
                    [
                        {"id": "12345", "username": "testuser1"},
                        {"id": "67890", "username": "testuser2"},
                    ],
                    {},
                    {},
                    None,
                )
            },
        )

    @staticmethod
    def get_muted_users(
        credentials: TwitterCredentials,
        max_results: int | None,
        pagination_token: str | None,
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "max_results": max_results,
                "pagination_token": (
                    None if pagination_token == "" else pagination_token
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_muted(**params))

            meta = {}
            user_ids = []
            usernames = []
            next_token = None

            if response.meta:
                meta = response.meta
                next_token = meta.get("next_token")

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                user_ids = [str(item.id) for item in response.data]
                usernames = [item.username for item in response.data]

            return user_ids, usernames, data, included, meta, next_token

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            ids, usernames, data, includes, meta, next_token = self.get_muted_users(
                credentials,
                input_data.max_results,
                input_data.pagination_token,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if usernames:
                yield "usernames", usernames
            if next_token:
                yield "next_token", next_token
            if data:
                yield "data", data
            if includes:
                yield "includes", includes
            if meta:
                yield "meta", meta
        except Exception as e:
            yield "error", handle_tweepy_exception(e)


class TwitterMuteUserBlock(Block):
    """
    Allows a user to mute another user specified by target user ID
    """

    class Input(BlockSchema):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "users.write", "offline.access"]
        )

        target_user_id: str = SchemaField(
            description="The user ID of the user that you would like to mute",
            placeholder="Enter target user ID",
        )

    class Output(BlockSchema):
        success: bool = SchemaField(
            description="Whether the mute action was successful"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="4d1919d0-a631-11ef-90ab-3b73af9ce8f1",
            description="This block mutes a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterMuteUserBlock.Input,
            output_schema=TwitterMuteUserBlock.Output,
            test_input={
                "target_user_id": "12345",
                "credentials": TEST_CREDENTIALS_INPUT,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("success", True),
            ],
            test_mock={"mute_user": lambda *args, **kwargs: True},
        )

    @staticmethod
    def mute_user(credentials: TwitterCredentials, target_user_id: str):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            client.mute(target_user_id=target_user_id, user_auth=False)

            return True

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            success = self.mute_user(credentials, input_data.target_user_id)
            yield "success", success
        except Exception as e:
            yield "error", handle_tweepy_exception(e)
@@ -1,383 +0,0 @@
from typing import Literal, Union, cast

import tweepy
from pydantic import BaseModel
from tweepy.client import Response

from backend.blocks.twitter._auth import (
    TEST_CREDENTIALS,
    TEST_CREDENTIALS_INPUT,
    TwitterCredentials,
    TwitterCredentialsField,
    TwitterCredentialsInput,
)
from backend.blocks.twitter._builders import UserExpansionsBuilder
from backend.blocks.twitter._serializer import (
    IncludesSerializer,
    ResponseDataSerializer,
)
from backend.blocks.twitter._types import (
    TweetFieldsFilter,
    TweetUserFieldsFilter,
    UserExpansionInputs,
    UserExpansionsFilter,
)
from backend.blocks.twitter.tweepy_exceptions import handle_tweepy_exception
from backend.data.block import Block, BlockCategory, BlockOutput, BlockSchema
from backend.data.model import SchemaField


class UserId(BaseModel):
    discriminator: Literal["user_id"]
    user_id: str = SchemaField(description="The ID of the user to lookup", default="")


class Username(BaseModel):
    discriminator: Literal["username"]
    username: str = SchemaField(
        description="The Twitter username (handle) of the user", default=""
    )


class TwitterGetUserBlock(Block):
    """
    Gets information about a single Twitter user specified by ID or username
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access"]
        )

        identifier: Union[UserId, Username] = SchemaField(
            discriminator="discriminator",
            description="Choose whether to identify the user by their unique Twitter ID or by their username",
            advanced=False,
        )

    class Output(BlockSchema):
        # Common outputs
        id: str = SchemaField(description="User ID")
        username_: str = SchemaField(description="User username")
        name_: str = SchemaField(description="User name")

        # Complete outputs
        data: dict = SchemaField(description="Complete user data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="5446db8e-a631-11ef-812a-cf315d373ee9",
            description="This block retrieves information about a specified Twitter user.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetUserBlock.Input,
            output_schema=TwitterGetUserBlock.Output,
            test_input={
                "identifier": {"discriminator": "username", "username": "twitter"},
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("id", "783214"),
                ("username_", "twitter"),
                ("name_", "Twitter"),
                (
                    "data",
                    {
                        "user": {
                            "id": "783214",
                            "username": "twitter",
                            "name": "Twitter",
                        }
                    },
                ),
            ],
            test_mock={
                "get_user": lambda *args, **kwargs: (
                    {
                        "user": {
                            "id": "783214",
                            "username": "twitter",
                            "name": "Twitter",
                        }
                    },
                    {},
                    "twitter",
                    "783214",
                    "Twitter",
                )
            },
        )

    @staticmethod
    def get_user(
        credentials: TwitterCredentials,
        identifier: Union[UserId, Username],
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
|
||||
)
|
||||
|
||||
params = {
|
||||
"id": identifier.user_id if isinstance(identifier, UserId) else None,
|
||||
"username": (
|
||||
identifier.username if isinstance(identifier, Username) else None
|
||||
),
|
||||
"user_auth": False,
|
||||
}
|
||||
|
||||
params = (
|
||||
UserExpansionsBuilder(params)
|
||||
.add_expansions(expansions)
|
||||
.add_tweet_fields(tweet_fields)
|
||||
.add_user_fields(user_fields)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = cast(Response, client.get_user(**params))
|
||||
|
||||
username = ""
|
||||
id = ""
|
||||
name = ""
|
||||
|
||||
included = IncludesSerializer.serialize(response.includes)
|
||||
data = ResponseDataSerializer.serialize_dict(response.data)
|
||||
|
||||
if response.data:
|
||||
username = response.data.username
|
||||
id = str(response.data.id)
|
||||
name = response.data.name
|
||||
|
||||
if username and id:
|
||||
return data, included, username, id, name
|
||||
else:
|
||||
raise tweepy.TweepyException("User not found")
|
||||
|
||||
except tweepy.TweepyException:
|
||||
raise
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: TwitterCredentials,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
data, included, username, id, name = self.get_user(
|
||||
credentials,
|
||||
input_data.identifier,
|
||||
input_data.expansions,
|
||||
input_data.tweet_fields,
|
||||
input_data.user_fields,
|
||||
)
|
||||
if id:
|
||||
yield "id", id
|
||||
if username:
|
||||
yield "username_", username
|
||||
if name:
|
||||
yield "name_", name
|
||||
if data:
|
||||
yield "data", data
|
||||
if included:
|
||||
yield "included", included
|
||||
except Exception as e:
|
||||
yield "error", handle_tweepy_exception(e)

class UserIdList(BaseModel):
    discriminator: Literal["user_id_list"]
    user_ids: list[str] = SchemaField(
        description="List of user IDs to lookup (max 100)",
        placeholder="Enter user IDs",
        default=[],
        advanced=False,
    )


class UsernameList(BaseModel):
    discriminator: Literal["username_list"]
    usernames: list[str] = SchemaField(
        description="List of Twitter usernames/handles to lookup (max 100)",
        placeholder="Enter usernames",
        default=[],
        advanced=False,
    )


class TwitterGetUsersBlock(Block):
    """
    Gets information about multiple Twitter users specified by IDs or usernames
    """

    class Input(UserExpansionInputs):
        credentials: TwitterCredentialsInput = TwitterCredentialsField(
            ["users.read", "offline.access"]
        )

        identifier: Union[UserIdList, UsernameList] = SchemaField(
            discriminator="discriminator",
            description="Choose whether to identify users by their unique Twitter IDs or by their usernames",
            advanced=False,
        )

    class Output(BlockSchema):
        # Common outputs
        ids: list[str] = SchemaField(description="User IDs")
        usernames_: list[str] = SchemaField(description="User usernames")
        names_: list[str] = SchemaField(description="User names")

        # Complete outputs
        data: list[dict] = SchemaField(description="Complete users data")
        included: dict = SchemaField(
            description="Additional data requested via expansions"
        )
        error: str = SchemaField(description="Error message if the request failed")

    def __init__(self):
        super().__init__(
            id="5abc857c-a631-11ef-8cfc-f7b79354f7a1",
            description="This block retrieves information about multiple Twitter users.",
            categories={BlockCategory.SOCIAL},
            input_schema=TwitterGetUsersBlock.Input,
            output_schema=TwitterGetUsersBlock.Output,
            test_input={
                "identifier": {
                    "discriminator": "username_list",
                    "usernames": ["twitter", "twitterdev"],
                },
                "credentials": TEST_CREDENTIALS_INPUT,
                "expansions": None,
                "tweet_fields": None,
                "user_fields": None,
            },
            test_credentials=TEST_CREDENTIALS,
            test_output=[
                ("ids", ["783214", "2244994945"]),
                ("usernames_", ["twitter", "twitterdev"]),
                ("names_", ["Twitter", "Twitter Dev"]),
                (
                    "data",
                    [
                        {"id": "783214", "username": "twitter", "name": "Twitter"},
                        {
                            "id": "2244994945",
                            "username": "twitterdev",
                            "name": "Twitter Dev",
                        },
                    ],
                ),
            ],
            test_mock={
                "get_users": lambda *args, **kwargs: (
                    [
                        {"id": "783214", "username": "twitter", "name": "Twitter"},
                        {
                            "id": "2244994945",
                            "username": "twitterdev",
                            "name": "Twitter Dev",
                        },
                    ],
                    {},
                    ["twitter", "twitterdev"],
                    ["783214", "2244994945"],
                    ["Twitter", "Twitter Dev"],
                )
            },
        )

    @staticmethod
    def get_users(
        credentials: TwitterCredentials,
        identifier: Union[UserIdList, UsernameList],
        expansions: UserExpansionsFilter | None,
        tweet_fields: TweetFieldsFilter | None,
        user_fields: TweetUserFieldsFilter | None,
    ):
        try:
            client = tweepy.Client(
                bearer_token=credentials.access_token.get_secret_value()
            )

            params = {
                "ids": (
                    ",".join(identifier.user_ids)
                    if isinstance(identifier, UserIdList)
                    else None
                ),
                "usernames": (
                    ",".join(identifier.usernames)
                    if isinstance(identifier, UsernameList)
                    else None
                ),
                "user_auth": False,
            }

            params = (
                UserExpansionsBuilder(params)
                .add_expansions(expansions)
                .add_tweet_fields(tweet_fields)
                .add_user_fields(user_fields)
                .build()
            )

            response = cast(Response, client.get_users(**params))

            usernames = []
            ids = []
            names = []

            included = IncludesSerializer.serialize(response.includes)
            data = ResponseDataSerializer.serialize_list(response.data)

            if response.data:
                for user in response.data:
                    usernames.append(user.username)
                    ids.append(str(user.id))
                    names.append(user.name)

            if usernames and ids:
                return data, included, usernames, ids, names
            else:
                raise tweepy.TweepyException("Users not found")

        except tweepy.TweepyException:
            raise

    def run(
        self,
        input_data: Input,
        *,
        credentials: TwitterCredentials,
        **kwargs,
    ) -> BlockOutput:
        try:
            data, included, usernames, ids, names = self.get_users(
                credentials,
                input_data.identifier,
                input_data.expansions,
                input_data.tweet_fields,
                input_data.user_fields,
            )
            if ids:
                yield "ids", ids
            if usernames:
                yield "usernames_", usernames
            if names:
                yield "names_", names
            if data:
                yield "data", data
            if included:
                yield "included", included
        except Exception as e:
            yield "error", handle_tweepy_exception(e)

@@ -22,10 +22,10 @@ from backend.util import json
from backend.util.settings import Config

from .model import (
    CREDENTIALS_FIELD_NAME,
    ContributorDetails,
    Credentials,
    CredentialsMetaInput,
    is_credentials_field_name,
)

app_config = Config()
@@ -42,7 +42,6 @@ class BlockType(Enum):
    OUTPUT = "Output"
    NOTE = "Note"
    WEBHOOK = "Webhook"
    WEBHOOK_MANUAL = "Webhook (manual)"
    AGENT = "Agent"


@@ -58,12 +57,8 @@ class BlockCategory(Enum):
    COMMUNICATION = "Block that interacts with communication platforms."
    DEVELOPER_TOOLS = "Developer tools such as GitHub blocks."
    DATA = "Block that interacts with structured data."
    HARDWARE = "Block that interacts with hardware."
    AGENT = "Block that interacts with other agents."
    CRM = "Block that interacts with CRM services."
    SAFETY = (
        "Block that provides AI safety mechanisms such as detecting harmful content"
    )

    def dict(self) -> dict[str, str]:
        return {"category": self.name, "description": self.value}
@@ -95,11 +90,15 @@ class BlockSchema(BaseModel):
                }
            elif isinstance(obj, list):
                return [ref_to_dict(item) for item in obj]

            return obj

        cls.cached_jsonschema = cast(dict[str, Any], ref_to_dict(model))

        # Set default properties values
        for field in cls.cached_jsonschema.get("properties", {}).values():
            if isinstance(field, dict) and "advanced" not in field:
                field["advanced"] = True

        return cls.cached_jsonschema

    @classmethod
@@ -141,38 +140,17 @@ class BlockSchema(BaseModel):
    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        """Validates the schema definition. Rules:
        - Fields with annotation `CredentialsMetaInput` MUST be
          named `credentials` or `*_credentials`
        - Fields named `credentials` or `*_credentials` MUST be
          of type `CredentialsMetaInput`
        - Only one `CredentialsMetaInput` field may be present.
          - This field MUST be called `credentials`.
          - A field that is called `credentials` MUST be a `CredentialsMetaInput`.
        """
        super().__pydantic_init_subclass__(**kwargs)

        # Reset cached JSON schema to prevent inheriting it from parent class
        cls.cached_jsonschema = {}

        credentials_fields = cls.get_credentials_fields()

        for field_name in cls.get_fields():
            if is_credentials_field_name(field_name):
                if field_name not in credentials_fields:
                    raise TypeError(
                        f"Credentials field '{field_name}' on {cls.__qualname__} "
                        f"is not of type {CredentialsMetaInput.__name__}"
                    )

                credentials_fields[field_name].validate_credentials_field_schema(cls)

            elif field_name in credentials_fields:
                raise KeyError(
                    f"Credentials field '{field_name}' on {cls.__qualname__} "
                    "has invalid name: must be 'credentials' or *_credentials"
                )

    @classmethod
    def get_credentials_fields(cls) -> dict[str, type[CredentialsMetaInput]]:
        return {
            field_name: info.annotation
        credentials_fields = [
            field_name
            for field_name, info in cls.model_fields.items()
            if (
                inspect.isclass(info.annotation)
@@ -181,7 +159,32 @@ class BlockSchema(BaseModel):
                    CredentialsMetaInput,
                )
            )
        }
        ]
        if len(credentials_fields) > 1:
            raise ValueError(
                f"{cls.__qualname__} can only have one CredentialsMetaInput field"
            )
        elif (
            len(credentials_fields) == 1
            and credentials_fields[0] != CREDENTIALS_FIELD_NAME
        ):
            raise ValueError(
                f"CredentialsMetaInput field on {cls.__qualname__} "
                "must be named 'credentials'"
            )
        elif (
            len(credentials_fields) == 0
            and CREDENTIALS_FIELD_NAME in cls.model_fields.keys()
        ):
            raise TypeError(
                f"Field 'credentials' on {cls.__qualname__} "
                f"must be of type {CredentialsMetaInput.__name__}"
            )
        if credentials_field := cls.model_fields.get(CREDENTIALS_FIELD_NAME):
            credentials_input_type = cast(
                CredentialsMetaInput, credentials_field.annotation
            )
            credentials_input_type.validate_credentials_field_schema(cls)
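
# --- Illustrative sketch: the validation above ties a field's *name* to its
# *type* in both directions. A standalone approximation, with `CredInput`
# standing in for CredentialsMetaInput (all names here are invented):

class CredInput:
    pass

def is_credentials_field_name(field_name: str) -> bool:
    return field_name == "credentials" or field_name.endswith("_credentials")

def check_schema(fields: dict[str, type]) -> None:
    for name, annotation in fields.items():
        cred_name = is_credentials_field_name(name)
        cred_type = isinstance(annotation, type) and issubclass(annotation, CredInput)
        if cred_name and not cred_type:
            raise TypeError(f"'{name}' must be a CredentialsMetaInput")
        if cred_type and not cred_name:
            raise KeyError(f"'{name}' must be named 'credentials' or '*_credentials'")

check_schema({"credentials": CredInput, "count": int})  # passes
# check_schema({"api": CredInput}) would raise KeyError: bad field name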

BlockSchemaInputType = TypeVar("BlockSchemaInputType", bound=BlockSchema)
@@ -193,12 +196,7 @@ class EmptySchema(BlockSchema):


# --8<-- [start:BlockWebhookConfig]
class BlockManualWebhookConfig(BaseModel):
    """
    Configuration model for webhook-triggered blocks on which
    the user has to manually set up the webhook at the provider.
    """

class BlockWebhookConfig(BaseModel):
    provider: str
    """The service provider that the webhook connects to"""

@@ -209,27 +207,6 @@ class BlockManualWebhookConfig(BaseModel):
    Only for use in the corresponding `WebhooksManager`.
    """

    event_filter_input: str = ""
    """
    Name of the block's event filter input.
    Leave empty if the corresponding webhook doesn't have distinct event/payload types.
    """

    event_format: str = "{event}"
    """
    Template string for the event(s) that a block instance subscribes to.
    Applied individually to each event selected in the event filter input.

    Example: `"pull_request.{event}"` -> `"pull_request.opened"`
    """


class BlockWebhookConfig(BlockManualWebhookConfig):
    """
    Configuration model for webhook-triggered blocks for which
    the webhook can be automatically set up through the provider's API.
    """

    resource_format: str
    """
    Template string for the resource that a block instance subscribes to.
@@ -239,6 +216,17 @@ class BlockWebhookConfig(BlockManualWebhookConfig):

    Only for use in the corresponding `WebhooksManager`.
    """

    event_filter_input: str
    """Name of the block's event filter input."""

    event_format: str = "{event}"
    """
    Template string for the event(s) that a block instance subscribes to.
    Applied individually to each event selected in the event filter input.

    Example: `"pull_request.{event}"` -> `"pull_request.opened"`
    """
    # --8<-- [end:BlockWebhookConfig]
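
# --- Illustrative sketch: `event_format` is a plain str.format template
# applied to each event selected in the event filter input, following the
# docstring example above.

event_format = "pull_request.{event}"
selected_events = ["opened", "closed"]
subscribed = [event_format.format(event=e) for e in selected_events]
assert subscribed == ["pull_request.opened", "pull_request.closed"]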

@@ -254,11 +242,11 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        test_input: BlockInput | list[BlockInput] | None = None,
        test_output: BlockData | list[BlockData] | None = None,
        test_mock: dict[str, Any] | None = None,
        test_credentials: Optional[Credentials | dict[str, Credentials]] = None,
        test_credentials: Optional[Credentials] = None,
        disabled: bool = False,
        static_output: bool = False,
        block_type: BlockType = BlockType.STANDARD,
        webhook_config: Optional[BlockWebhookConfig | BlockManualWebhookConfig] = None,
        webhook_config: Optional[BlockWebhookConfig] = None,
    ):
        """
        Initialize the block with the given schema.
@@ -289,44 +277,27 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
        self.contributors = contributors or set()
        self.disabled = disabled
        self.static_output = static_output
        self.block_type = block_type
        self.block_type = block_type if not webhook_config else BlockType.WEBHOOK
        self.webhook_config = webhook_config
        self.execution_stats = {}

        if self.webhook_config:
            if isinstance(self.webhook_config, BlockWebhookConfig):
                # Enforce presence of credentials field on auto-setup webhook blocks
                if not (cred_fields := self.input_schema.get_credentials_fields()):
                    raise TypeError(
                        "credentials field is required on auto-setup webhook blocks"
                    )
                # Disallow multiple credentials inputs on webhook blocks
                elif len(cred_fields) > 1:
                    raise ValueError(
                        "Multiple credentials inputs not supported on webhook blocks"
                    )

                self.block_type = BlockType.WEBHOOK
            else:
                self.block_type = BlockType.WEBHOOK_MANUAL

            # Enforce shape of webhook event filter, if present
            if self.webhook_config.event_filter_input:
                event_filter_field = self.input_schema.model_fields[
                    self.webhook_config.event_filter_input
                ]
                if not (
                    isinstance(event_filter_field.annotation, type)
                    and issubclass(event_filter_field.annotation, BaseModel)
                    and all(
                        field.annotation is bool
                        for field in event_filter_field.annotation.model_fields.values()
                    )
                ):
                    raise NotImplementedError(
                        f"{self.name} has an invalid webhook event selector: "
                        "field must be a BaseModel and all its fields must be boolean"
                    )
            # Enforce shape of webhook event filter
            event_filter_field = self.input_schema.model_fields[
                self.webhook_config.event_filter_input
            ]
            if not (
                isinstance(event_filter_field.annotation, type)
                and issubclass(event_filter_field.annotation, BaseModel)
                and all(
                    field.annotation is bool
                    for field in event_filter_field.annotation.model_fields.values()
                )
            ):
                raise NotImplementedError(
                    f"{self.name} has an invalid webhook event selector: "
                    "field must be a BaseModel and all its fields must be boolean"
                )

            # Enforce presence of 'payload' input
            if "payload" not in self.input_schema.model_fields:

@@ -51,7 +51,6 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.LLAMA3_1_405B: 1,
    LlmModel.LLAMA3_1_70B: 1,
    LlmModel.LLAMA3_1_8B: 1,
    LlmModel.OLLAMA_LLAMA3_2: 1,
    LlmModel.OLLAMA_LLAMA3_8B: 1,
    LlmModel.OLLAMA_LLAMA3_405B: 1,
    LlmModel.OLLAMA_DOLPHIN: 1,

@@ -270,9 +270,9 @@ async def update_graph_execution_start_time(graph_exec_id: str):

async def update_graph_execution_stats(
    graph_exec_id: str,
    status: ExecutionStatus,
    stats: dict[str, Any],
) -> ExecutionResult:
    status = ExecutionStatus.FAILED if stats.get("error") else ExecutionStatus.COMPLETED
    res = await AgentGraphExecution.prisma().update(
        where={"id": graph_exec_id},
        data={

@@ -84,8 +84,6 @@ class NodeModel(Node):
            raise ValueError(f"Block #{self.block_id} not found for node #{self.id}")
        if not block.webhook_config:
            raise TypeError("This method can't be used on non-webhook blocks")
        if not block.webhook_config.event_filter_input:
            return True
        event_filter = self.input_default.get(block.webhook_config.event_filter_input)
        if not event_filter:
            raise ValueError(f"Event filter is not configured on node #{self.id}")
@@ -193,8 +191,7 @@ class Graph(BaseDbModel):
            "properties": {
                p.name: {
                    "secret": p.secret,
                    # Default value has to be set for advanced fields.
                    "advanced": p.advanced and p.value is not None,
                    "advanced": p.advanced,
                    "title": p.title or p.name,
                    **({"description": p.description} if p.description else {}),
                    **({"default": p.value} if p.value is not None else {}),
@@ -260,7 +257,7 @@ class GraphModel(Graph):
        for link in self.links:
            input_links[link.sink_id].append(link)

        # Nodes: required fields are filled or connected and dependencies are satisfied
        # Nodes: required fields are filled or connected
        for node in self.nodes:
            block = get_block(node.block_id)
            if block is None:
@@ -271,55 +268,16 @@ class GraphModel(Graph):
                + [sanitize(link.sink_name) for link in input_links.get(node.id, [])]
            )
            for name in block.input_schema.get_required_fields():
                if (
                    name not in provided_inputs
                    and not (
                        name == "payload"
                        and block.block_type
                        in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                    )
                    and (
                        for_run  # Skip input completion validation, unless when executing.
                        or block.block_type == BlockType.INPUT
                        or block.block_type == BlockType.OUTPUT
                        or block.block_type == BlockType.AGENT
                    )
                if name not in provided_inputs and (
                    for_run  # Skip input completion validation, unless when executing.
                    or block.block_type == BlockType.INPUT
                    or block.block_type == BlockType.OUTPUT
                    or block.block_type == BlockType.AGENT
                ):
                    raise ValueError(
                        f"Node {block.name} #{node.id} required input missing: `{name}`"
                    )

            # Get input schema properties and check dependencies
            input_schema = block.input_schema.model_fields
            required_fields = block.input_schema.get_required_fields()

            def has_value(name):
                return (
                    node is not None
                    and name in node.input_default
                    and node.input_default[name] is not None
                    and str(node.input_default[name]).strip() != ""
                ) or (name in input_schema and input_schema[name].default is not None)

            # Validate dependencies between fields
            for field_name, field_info in input_schema.items():
                # Apply input dependency validation only on run & field with depends_on
                json_schema_extra = field_info.json_schema_extra or {}
                dependencies = json_schema_extra.get("depends_on", [])
                if not for_run or not dependencies:
                    continue

                # Check if dependent field has value in input_default
                field_has_value = has_value(field_name)
                field_is_required = field_name in required_fields

                # Check for missing dependencies when dependent field is present
                missing_deps = [dep for dep in dependencies if not has_value(dep)]
                if missing_deps and (field_has_value or field_is_required):
                    raise ValueError(
                        f"Node {block.name} #{node.id}: Field `{field_name}` requires [{', '.join(missing_deps)}] to be set"
                    )
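
# --- Illustrative sketch: the depends_on rule above in miniature. A field
# that is set (or required) must have every field it depends on set as well
# (`check_dependencies` is invented for the example):

def check_dependencies(
    values: dict, depends_on: dict[str, list[str]], required: set[str]
) -> None:
    def has_value(name: str) -> bool:
        return values.get(name) not in (None, "")

    for field, deps in depends_on.items():
        missing = [d for d in deps if not has_value(d)]
        if missing and (has_value(field) or field in required):
            raise ValueError(f"Field `{field}` requires {missing} to be set")

check_dependencies({"host": "db1", "port": 5432}, {"port": ["host"]}, set())  # OK
try:
    check_dependencies({"port": 5432}, {"port": ["host"]}, set())
except ValueError as e:
    print(e)  # Field `port` requires ['host'] to be set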

        node_map = {v.id: v for v in self.nodes}

        def is_static_output_block(nid: str) -> bool:
@@ -369,7 +327,7 @@ class GraphModel(Graph):
                link.is_static = True  # Each value block output should be static.

    @staticmethod
    def from_db(graph: AgentGraph, for_export: bool = False):
    def from_db(graph: AgentGraph, hide_credentials: bool = False):
        return GraphModel(
            id=graph.id,
            user_id=graph.userId,
@@ -379,7 +337,7 @@ class GraphModel(Graph):
            name=graph.name or "",
            description=graph.description or "",
            nodes=[
                NodeModel.from_db(GraphModel._process_node(node, for_export))
                GraphModel._process_node(node, hide_credentials)
                for node in graph.AgentNodes or []
            ],
            links=list(
@@ -392,29 +350,23 @@ class GraphModel(Graph):
        )

    @staticmethod
    def _process_node(node: AgentNode, for_export: bool) -> AgentNode:
        if for_export:
            # Remove credentials from node input
            if node.constantInput:
                constant_input = json.loads(
                    node.constantInput, target_type=dict[str, Any]
                )
                constant_input = GraphModel._hide_node_input_credentials(constant_input)
                node.constantInput = json.dumps(constant_input)

            # Remove webhook info
            node.webhookId = None
            node.Webhook = None

        return node
    def _process_node(node: AgentNode, hide_credentials: bool) -> NodeModel:
        node_dict = {field: getattr(node, field) for field in node.model_fields}
        if hide_credentials and "constantInput" in node_dict:
            constant_input = json.loads(
                node_dict["constantInput"], target_type=dict[str, Any]
            )
            constant_input = GraphModel._hide_credentials_in_input(constant_input)
            node_dict["constantInput"] = json.dumps(constant_input)
        return NodeModel.from_db(AgentNode(**node_dict))

    @staticmethod
    def _hide_node_input_credentials(input_data: dict[str, Any]) -> dict[str, Any]:
    def _hide_credentials_in_input(input_data: dict[str, Any]) -> dict[str, Any]:
        sensitive_keys = ["credentials", "api_key", "password", "token", "secret"]
        result = {}
        for key, value in input_data.items():
            if isinstance(value, dict):
                result[key] = GraphModel._hide_node_input_credentials(value)
                result[key] = GraphModel._hide_credentials_in_input(value)
            elif isinstance(value, str) and any(
                sensitive_key in key.lower() for sensitive_key in sensitive_keys
            ):
@@ -424,26 +376,6 @@ class GraphModel(Graph):
                result[key] = value
        return result
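
# --- Illustrative sketch: the recursive redaction above in standalone form.
# String values under keys containing a sensitive substring are masked (the
# "********" mask is an assumption, since the hunk above elides the actual
# replacement value), and nested dicts are walked recursively.

from typing import Any

SENSITIVE = ["credentials", "api_key", "password", "token", "secret"]

def hide_credentials(data: dict[str, Any]) -> dict[str, Any]:
    result: dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, dict):
            result[key] = hide_credentials(value)
        elif isinstance(value, str) and any(s in key.lower() for s in SENSITIVE):
            result[key] = "********"  # assumed mask value
        else:
            result[key] = value
    return result

redacted = hide_credentials({"api_key": "sk-123", "nested": {"password": "hunter2"}})
assert redacted == {"api_key": "********", "nested": {"password": "********"}}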

    def clean_graph(self):
        blocks = [block() for block in get_blocks().values()]

        input_blocks = [
            node
            for node in self.nodes
            if next(
                (
                    b
                    for b in blocks
                    if b.id == node.block_id and b.block_type == BlockType.INPUT
                ),
                None,
            )
        ]

        for node in self.nodes:
            if any(input_block.id == node.id for input_block in input_blocks):
                node.input_default["value"] = ""


# --------------------- CRUD functions --------------------- #

@@ -500,15 +432,7 @@ async def get_graphs(
        include=AGENT_GRAPH_INCLUDE,
    )

    graph_models = []
    for graph in graphs:
        try:
            graph_models.append(GraphModel.from_db(graph))
        except Exception as e:
            logger.error(f"Error processing graph {graph.id}: {e}")
            continue

    return graph_models
    return [GraphModel.from_db(graph) for graph in graphs]


async def get_executions(user_id: str) -> list[GraphExecution]:
@@ -531,7 +455,7 @@ async def get_graph(
    version: int | None = None,
    template: bool = False,
    user_id: str | None = None,
    for_export: bool = False,
    hide_credentials: bool = False,
) -> GraphModel | None:
    """
    Retrieves a graph from the DB.
@@ -542,13 +466,13 @@ async def get_graph(
    """
    where_clause: AgentGraphWhereInput = {
        "id": graph_id,
        "isTemplate": template,
    }
    if version is not None:
        where_clause["version"] = version
    elif not template:
        where_clause["isActive"] = True

    # TODO: Fix hack workaround to get adding store agents to work
    if user_id is not None and not template:
        where_clause["userId"] = user_id

@@ -557,7 +481,7 @@ async def get_graph(
        include=AGENT_GRAPH_INCLUDE,
        order={"version": "desc"},
    )
    return GraphModel.from_db(graph, for_export) if graph else None
    return GraphModel.from_db(graph, hide_credentials) if graph else None


async def set_graph_active_version(graph_id: str, version: int, user_id: str) -> None:
@@ -629,20 +553,25 @@ async def __create_graph(tx, graph: Graph, user_id: str):
            "isTemplate": graph.is_template,
            "isActive": graph.is_active,
            "userId": user_id,
            "AgentNodes": {
                "create": [
                    {
                        "id": node.id,
                        "agentBlockId": node.block_id,
                        "constantInput": json.dumps(node.input_default),
                        "metadata": json.dumps(node.metadata),
                    }
                    for node in graph.nodes
                ]
            },
        }
    )

    await asyncio.gather(
        *[
            AgentNode.prisma(tx).create(
                {
                    "id": node.id,
                    "agentBlockId": node.block_id,
                    "agentGraphId": graph.id,
                    "agentGraphVersion": graph.version,
                    "constantInput": json.dumps(node.input_default),
                    "metadata": json.dumps(node.metadata),
                }
            )
            for node in graph.nodes
        ]
    )
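
# --- Illustrative sketch: the node-creation fan-out above issues one
# coroutine per record and awaits them together with asyncio.gather.

import asyncio

async def create_record(record: dict) -> str:
    await asyncio.sleep(0)  # stand-in for a database round trip
    return record["id"]

async def create_all(records: list[dict]) -> list[str]:
    # gather() returns results in input order even if the coroutines
    # complete in a different order.
    return await asyncio.gather(*[create_record(r) for r in records])

assert asyncio.run(create_all([{"id": "n1"}, {"id": "n2"}])) == ["n1", "n2"]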

    await asyncio.gather(
        *[
            AgentNodeLink.prisma(tx).create(

@@ -3,12 +3,11 @@ from typing import TYPE_CHECKING, AsyncGenerator, Optional

from prisma import Json
from prisma.models import IntegrationWebhook
from pydantic import Field, computed_field
from pydantic import Field

from backend.data.includes import INTEGRATION_WEBHOOK_INCLUDE
from backend.data.queue import AsyncRedisEventBus
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url

from .db import BaseDbModel

@@ -32,11 +31,6 @@ class Webhook(BaseDbModel):

    attached_nodes: Optional[list["NodeModel"]] = None

    @computed_field
    @property
    def url(self) -> str:
        return webhook_ingress_url(self.provider, self.id)

    @staticmethod
    def from_db(webhook: IntegrationWebhook):
        from .graph import NodeModel
@@ -90,10 +84,8 @@ async def get_webhook(webhook_id: str) -> Webhook:
    return Webhook.from_db(webhook)


async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
async def get_all_webhooks(credentials_id: str) -> list[Webhook]:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    if not credentials_id:
        raise ValueError("credentials_id must not be empty")
    webhooks = await IntegrationWebhook.prisma().find_many(
        where={"credentialsId": credentials_id},
        include=INTEGRATION_WEBHOOK_INCLUDE,
@@ -101,7 +93,7 @@ async def get_all_webhooks_by_creds(credentials_id: str) -> list[Webhook]:
    return [Webhook.from_db(webhook) for webhook in webhooks]


async def find_webhook_by_credentials_and_props(
async def find_webhook(
    credentials_id: str, webhook_type: str, resource: str, events: list[str]
) -> Webhook | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
@@ -117,22 +109,6 @@ async def find_webhook_by_credentials_and_props(
    return Webhook.from_db(webhook) if webhook else None


async def find_webhook_by_graph_and_props(
    graph_id: str, provider: str, webhook_type: str, events: list[str]
) -> Webhook | None:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    webhook = await IntegrationWebhook.prisma().find_first(
        where={
            "provider": provider,
            "webhookType": webhook_type,
            "events": {"has_every": events},
            "AgentNodes": {"some": {"agentGraphId": graph_id}},
        },
        include=INTEGRATION_WEBHOOK_INCLUDE,
    )
    return Webhook.from_db(webhook) if webhook else None


async def update_webhook_config(webhook_id: str, updated_config: dict) -> Webhook:
    """⚠️ No `user_id` check: DO NOT USE without check in user-facing endpoints."""
    _updated_webhook = await IntegrationWebhook.prisma().update(

@@ -134,20 +134,12 @@ def SchemaField(
    title: Optional[str] = None,
    description: Optional[str] = None,
    placeholder: Optional[str] = None,
    advanced: Optional[bool] = None,
    advanced: Optional[bool] = False,
    secret: bool = False,
    exclude: bool = False,
    hidden: Optional[bool] = None,
    depends_on: list[str] | None = None,
    image_upload: Optional[bool] = None,
    image_output: Optional[bool] = None,
    **kwargs,
) -> T:
    if default is PydanticUndefined and default_factory is None:
        advanced = False
    elif advanced is None:
        advanced = True

    json_extra = {
        k: v
        for k, v in {
@@ -155,9 +147,6 @@ def SchemaField(
            "secret": secret,
            "advanced": advanced,
            "hidden": hidden,
            "depends_on": depends_on,
            "image_upload": image_upload,
            "image_output": image_output,
        }.items()
        if v is not None
    }
@@ -226,7 +215,6 @@ class OAuthState(BaseModel):
    token: str
    provider: str
    expires_at: int
    code_verifier: Optional[str] = None
    """Unix timestamp (seconds) indicating when this OAuth state expires"""
    scopes: list[str]

@@ -250,8 +238,7 @@ CP = TypeVar("CP", bound=ProviderName)
CT = TypeVar("CT", bound=CredentialsType)


def is_credentials_field_name(field_name: str) -> bool:
    return field_name == "credentials" or field_name.endswith("_credentials")
CREDENTIALS_FIELD_NAME = "credentials"


class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
@@ -260,21 +247,21 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
    provider: CP
    type: CT

    @classmethod
    def allowed_providers(cls) -> tuple[ProviderName, ...]:
        return get_args(cls.model_fields["provider"].annotation)
    @staticmethod
    def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
        schema["credentials_provider"] = get_args(
            cls.model_fields["provider"].annotation
        )
        schema["credentials_types"] = get_args(cls.model_fields["type"].annotation)

    @classmethod
    def allowed_cred_types(cls) -> tuple[CredentialsType, ...]:
        return get_args(cls.model_fields["type"].annotation)
    model_config = ConfigDict(
        json_schema_extra=_add_json_schema_extra,  # type: ignore
    )

    @classmethod
    def validate_credentials_field_schema(cls, model: type["BlockSchema"]):
        """Validates the schema of a credentials input field"""
        field_name = next(
            name for name, type in model.get_credentials_fields().items() if type is cls
        )
        field_schema = model.jsonschema()["properties"][field_name]
        """Validates the schema of a `credentials` field"""
        field_schema = model.jsonschema()["properties"][CREDENTIALS_FIELD_NAME]
        try:
            schema_extra = _CredentialsFieldSchemaExtra[CP, CT].model_validate(
                field_schema
@@ -288,20 +275,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]):
                f"{field_schema}"
            ) from e

        if len(cls.allowed_providers()) > 1 and not schema_extra.discriminator:
            raise TypeError(
                f"Multi-provider CredentialsField '{field_name}' "
                "requires discriminator!"
            )

    @staticmethod
    def _add_json_schema_extra(schema, cls: CredentialsMetaInput):
        schema["credentials_provider"] = cls.allowed_providers()
        schema["credentials_types"] = cls.allowed_cred_types()

    model_config = ConfigDict(
        json_schema_extra=_add_json_schema_extra,  # type: ignore
    )
        if (
            len(schema_extra.credentials_provider) > 1
            and not schema_extra.discriminator
        ):
            raise TypeError("Multi-provider CredentialsField requires discriminator!")


class _CredentialsFieldSchemaExtra(BaseModel, Generic[CP, CT]):

@@ -10,6 +10,7 @@ from contextlib import contextmanager
from multiprocessing.pool import AsyncResult, Pool
from typing import TYPE_CHECKING, Any, Generator, TypeVar, cast

from pydantic import BaseModel
from redis.lock import Lock as RedisLock

if TYPE_CHECKING:
@@ -19,14 +20,7 @@ from autogpt_libs.utils.cache import thread_cached

from backend.blocks.agent import AgentExecutorBlock
from backend.data import redis
from backend.data.block import (
    Block,
    BlockData,
    BlockInput,
    BlockSchema,
    BlockType,
    get_block,
)
from backend.data.block import Block, BlockData, BlockInput, BlockType, get_block
from backend.data.execution import (
    ExecutionQueue,
    ExecutionResult,
@@ -37,6 +31,7 @@ from backend.data.execution import (
    parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.data.model import CREDENTIALS_FIELD_NAME, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
@@ -175,11 +170,10 @@ def execute_node(
    # one (running) block at a time; simultaneous execution of blocks using same
    # credentials is not supported.
    creds_lock = None
    input_model = cast(type[BlockSchema], node_block.input_schema)
    for field_name, input_type in input_model.get_credentials_fields().items():
        credentials_meta = input_type(**input_data[field_name])
    if CREDENTIALS_FIELD_NAME in input_data:
        credentials_meta = CredentialsMetaInput(**input_data[CREDENTIALS_FIELD_NAME])
        credentials, creds_lock = creds_manager.acquire(user_id, credentials_meta.id)
        extra_exec_kwargs[field_name] = credentials
        extra_exec_kwargs["credentials"] = credentials

    output_size = 0
    end_status = ExecutionStatus.COMPLETED
@@ -597,7 +591,7 @@ class Executor:
            node_eid="*",
            block_name="-",
        )
        timing_info, (exec_stats, status, error) = cls._on_graph_execution(
        timing_info, (exec_stats, error) = cls._on_graph_execution(
            graph_exec, cancel, log_metadata
        )
        exec_stats["walltime"] = timing_info.wall_time
@@ -605,7 +599,6 @@ class Executor:
        exec_stats["error"] = str(error) if error else None
        result = cls.db_client.update_graph_execution_stats(
            graph_exec_id=graph_exec.graph_exec_id,
            status=status,
            stats=exec_stats,
        )
        cls.db_client.send_execution_update(result)
@@ -617,12 +610,11 @@ class Executor:
        graph_exec: GraphExecutionEntry,
        cancel: threading.Event,
        log_metadata: LogMetadata,
    ) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
    ) -> tuple[dict[str, Any], Exception | None]:
        """
        Returns:
            dict: The execution statistics of the graph execution.
            ExecutionStatus: The final status of the graph execution.
            Exception | None: The error that occurred during the execution, if any.
            The execution statistics of the graph execution.
            The error that occurred during the execution.
        """
        log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
        exec_stats = {
@@ -667,7 +659,8 @@ class Executor:

            while not queue.empty():
                if cancel.is_set():
                    return exec_stats, ExecutionStatus.TERMINATED, error
                    error = RuntimeError("Execution is cancelled")
                    return exec_stats, error

                exec_data = queue.get()

@@ -697,7 +690,8 @@ class Executor:
                )
                for node_id, execution in list(running_executions.items()):
                    if cancel.is_set():
                        return exec_stats, ExecutionStatus.TERMINATED, error
                        error = RuntimeError("Execution is cancelled")
                        return exec_stats, error

                    if not queue.empty():
                        break  # yield to parent loop to execute new queue items
@@ -716,12 +710,7 @@ class Executor:
            finished = True
            cancel.set()
            cancel_thread.join()

            return (
                exec_stats,
                ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
                error,
            )
            return exec_stats, error


class ExecutionManager(AppService):
@@ -809,13 +798,10 @@ class ExecutionManager(AppService):
            # Extract webhook payload, and assign it to the input pin
            webhook_payload_key = f"webhook_{node.webhook_id}_payload"
            if (
                block.block_type in (BlockType.WEBHOOK, BlockType.WEBHOOK_MANUAL)
                block.block_type == BlockType.WEBHOOK
                and node.webhook_id
                and webhook_payload_key in data
            ):
                if webhook_payload_key not in data:
                    raise ValueError(
                        f"Node {block.name} #{node.id} webhook payload is missing"
                    )
                input_data = {"payload": data[webhook_payload_key]}

            input_data, error = validate_exec(node, input_data)
@@ -887,8 +873,11 @@ class ExecutionManager(AppService):
                ExecutionStatus.COMPLETED,
                ExecutionStatus.FAILED,
            ):
                self.db_client.upsert_execution_output(
                    node_exec.node_exec_id, "error", "TERMINATED"
                )
                exec_update = self.db_client.update_execution_status(
                    node_exec.node_exec_id, ExecutionStatus.TERMINATED
                    node_exec.node_exec_id, ExecutionStatus.FAILED
                )
                self.db_client.send_execution_update(exec_update)

@@ -901,39 +890,41 @@ class ExecutionManager(AppService):
                raise ValueError(f"Unknown block {node.block_id} for node #{node.id}")

            # Find any fields of type CredentialsMetaInput
            credentials_fields = cast(
                type[BlockSchema], block.input_schema
            ).get_credentials_fields()
            if not credentials_fields:
            model_fields = cast(type[BaseModel], block.input_schema).model_fields
            if CREDENTIALS_FIELD_NAME not in model_fields:
                continue

            for field_name, credentials_meta_type in credentials_fields.items():
                credentials_meta = credentials_meta_type.model_validate(
                    node.input_default[field_name]
            field = model_fields[CREDENTIALS_FIELD_NAME]

            # The BlockSchema class enforces that a `credentials` field is always a
            # `CredentialsMetaInput`, so we can safely assume this here.
            credentials_meta_type = cast(CredentialsMetaInput, field.annotation)
            credentials_meta = credentials_meta_type.model_validate(
                node.input_default[CREDENTIALS_FIELD_NAME]
            )
            # Fetch the corresponding Credentials and perform sanity checks
            credentials = self.credentials_store.get_creds_by_id(
                user_id, credentials_meta.id
            )
            if not credentials:
                raise ValueError(
                    f"Unknown credentials #{credentials_meta.id} "
                    f"for node #{node.id}"
                )
                # Fetch the corresponding Credentials and perform sanity checks
                credentials = self.credentials_store.get_creds_by_id(
                    user_id, credentials_meta.id
            if (
                credentials.provider != credentials_meta.provider
                or credentials.type != credentials_meta.type
            ):
                logger.warning(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch: "
                    f"{credentials_meta.type}<>{credentials.type};"
                    f"{credentials_meta.provider}<>{credentials.provider}"
                )
                raise ValueError(
                    f"Invalid credentials #{credentials.id} for node #{node.id}: "
                    "type/provider mismatch"
                )
                if not credentials:
                    raise ValueError(
                        f"Unknown credentials #{credentials_meta.id} "
                        f"for node #{node.id} input '{field_name}'"
                    )
                if (
                    credentials.provider != credentials_meta.provider
                    or credentials.type != credentials_meta.type
                ):
                    logger.warning(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch: "
                        f"{credentials_meta.type}<>{credentials.type};"
                        f"{credentials_meta.provider}<>{credentials.provider}"
                    )
                    raise ValueError(
                        f"Invalid credentials #{credentials.id} for node #{node.id}: "
                        "type/provider mismatch"
                    )


# ------- UTILITIES ------- #
@@ -953,8 +944,7 @@ def synchronized(key: str, timeout: int = 60):
        lock.acquire()
        yield
    finally:
        if lock.locked():
            lock.release()
        lock.release()
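
# --- Illustrative sketch: why the guarded release above matters. A TTL-based
# Redis lock can expire while the critical section is still running, and
# releasing it unconditionally then raises; the guarded form checks first.
# Shown with threading.Lock as a rough local stand-in for redis.lock.Lock:

import threading
from contextlib import contextmanager

@contextmanager
def synchronized_demo(lock: threading.Lock):
    lock.acquire()
    try:
        yield
    finally:
        if lock.locked():  # release only if the lock is still held
            lock.release()

with synchronized_demo(threading.Lock()):
    pass  # critical section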

def llprint(message: str):

@@ -99,10 +99,6 @@ class ExecutionScheduler(AppService):
    def get_port(cls) -> int:
        return config.execution_scheduler_port

    @classmethod
    def db_pool_size(cls) -> int:
        return config.scheduler_db_pool_size

    @property
    @thread_cached
    def execution_client(self) -> ExecutionManager:
@@ -114,11 +110,7 @@ class ExecutionScheduler(AppService):
        self.scheduler = BlockingScheduler(
            jobstores={
                "default": SQLAlchemyJobStore(
                    engine=create_engine(
                        url=db_url,
                        pool_size=self.db_pool_size(),
                        max_overflow=0,
                    ),
                    engine=create_engine(db_url),
                    metadata=MetaData(schema=db_schema),
                )
            }

@@ -1,8 +1,6 @@
import base64
import hashlib
import secrets
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

from pydantic import SecretStr

@@ -93,34 +91,6 @@ open_router_credentials = APIKeyCredentials(
    title="Use Credits for Open Router",
    expires_at=None,
)
fal_credentials = APIKeyCredentials(
    id="6c0f5bd0-9008-4638-9d79-4b40b631803e",
    provider="fal",
    api_key=SecretStr(settings.secrets.fal_api_key),
    title="Use Credits for FAL",
    expires_at=None,
)
exa_credentials = APIKeyCredentials(
    id="96153e04-9c6c-4486-895f-5bb683b1ecec",
    provider="exa",
    api_key=SecretStr(settings.secrets.exa_api_key),
    title="Use Credits for Exa search",
    expires_at=None,
)
e2b_credentials = APIKeyCredentials(
    id="78d19fd7-4d59-4a16-8277-3ce310acf2b7",
    provider="e2b",
    api_key=SecretStr(settings.secrets.e2b_api_key),
    title="Use Credits for E2B",
    expires_at=None,
)
nvidia_credentials = APIKeyCredentials(
    id="96b83908-2789-4dec-9968-18f0ece4ceb3",
    provider="nvidia",
    api_key=SecretStr(settings.secrets.nvidia_api_key),
    title="Use Credits for Nvidia",
    expires_at=None,
)


DEFAULT_CREDENTIALS = [
@@ -134,10 +104,6 @@ DEFAULT_CREDENTIALS = [
    jina_credentials,
    unreal_credentials,
    open_router_credentials,
    fal_credentials,
    exa_credentials,
    e2b_credentials,
    nvidia_credentials,
]


@@ -189,14 +155,6 @@ class IntegrationCredentialsStore:
            all_credentials.append(unreal_credentials)
        if settings.secrets.open_router_api_key:
            all_credentials.append(open_router_credentials)
        if settings.secrets.fal_api_key:
            all_credentials.append(fal_credentials)
        if settings.secrets.exa_api_key:
            all_credentials.append(exa_credentials)
        if settings.secrets.e2b_api_key:
            all_credentials.append(e2b_credentials)
        if settings.secrets.nvidia_api_key:
            all_credentials.append(nvidia_credentials)
        return all_credentials

    def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
@@ -252,24 +210,18 @@ class IntegrationCredentialsStore:
        ]
        self._set_user_integration_creds(user_id, filtered_credentials)

    def store_state_token(
        self, user_id: str, provider: str, scopes: list[str], use_pkce: bool = False
    ) -> tuple[str, str]:
    def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
        token = secrets.token_urlsafe(32)
        expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)

        (code_challenge, code_verifier) = self._generate_code_challenge()

        state = OAuthState(
            token=token,
            provider=provider,
            code_verifier=code_verifier,
            expires_at=int(expires_at.timestamp()),
            scopes=scopes,
        )

        with self.locked_user_integrations(user_id):

            user_integrations = self._get_user_integrations(user_id)
            oauth_states = user_integrations.oauth_states
            oauth_states.append(state)
@@ -279,21 +231,39 @@ class IntegrationCredentialsStore:
                user_id=user_id, data=user_integrations
            )

        return token, code_challenge
        return token

    def _generate_code_challenge(self) -> tuple[str, str]:
        """
        Generate the code challenge from the code verifier using SHA256.
        Currently only SHA256 is supported. (If more methods are wanted in the
        future, they can be added here.)
        """
        code_verifier = secrets.token_urlsafe(128)
        sha256_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
        code_challenge = base64.urlsafe_b64encode(sha256_hash).decode("utf-8")
        return code_challenge.replace("=", ""), code_verifier
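
# --- Illustrative sketch: the S256 PKCE scheme used above, end to end. The
# verifier stays with the client; only its hash (the challenge) goes out in
# the authorization request, and the server later recomputes the hash from
# the presented verifier (RFC 7636).

import base64
import hashlib
import secrets

def make_pkce_pair() -> tuple[str, str]:
    verifier = secrets.token_urlsafe(128)
    digest = hashlib.sha256(verifier.encode("utf-8")).digest()
    # Base64url without padding, per RFC 7636
    challenge = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
    return challenge, verifier

challenge, verifier = make_pkce_pair()
# Server-side check during the token exchange:
recomputed = base64.urlsafe_b64encode(
    hashlib.sha256(verifier.encode("utf-8")).digest()
).decode("utf-8").rstrip("=")
assert recomputed == challenge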

    def verify_state_token(
    def get_any_valid_scopes_from_state_token(
        self, user_id: str, token: str, provider: str
    ) -> Optional[OAuthState]:
    ) -> list[str]:
        """
        Get the valid scopes from the OAuth state token. This will return any valid scopes
        from any OAuth state token for the given provider. If no valid scopes are found,
        an empty list is returned. DO NOT RELY ON THIS TOKEN TO AUTHENTICATE A USER, AS IT
        IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
        THE CODE FOR TOKENS.
        """
        user_integrations = self._get_user_integrations(user_id)
        oauth_states = user_integrations.oauth_states

        now = datetime.now(timezone.utc)
        valid_state = next(
            (
                state
                for state in oauth_states
                if state.token == token
                and state.provider == provider
                and state.expires_at > now.timestamp()
            ),
            None,
        )

        if valid_state:
            return valid_state.scopes

        return []

    def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
        with self.locked_user_integrations(user_id):
            user_integrations = self._get_user_integrations(user_id)
            oauth_states = user_integrations.oauth_states
@@ -315,9 +285,9 @@ class IntegrationCredentialsStore:
            oauth_states.remove(valid_state)
            user_integrations.oauth_states = oauth_states
            self.db_manager.update_user_integrations(user_id, user_integrations)
            return valid_state
            return True

        return None
        return False

    def _set_user_integration_creds(
        self, user_id: str, credentials: list[Credentials]

@@ -92,7 +92,7 @@ class IntegrationCredentialsManager:

            fresh_credentials = oauth_handler.refresh_tokens(credentials)
            self.store.update_creds(user_id, fresh_credentials)
            if _lock and _lock.locked():
            if _lock:
                _lock.release()

            credentials = fresh_credentials
@@ -144,8 +144,7 @@ class IntegrationCredentialsManager:
        try:
            yield
        finally:
            if lock.locked():
                lock.release()
            lock.release()

    def release_all_locks(self):
        """Call this on process termination to ensure all locks are released"""

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING
from .github import GitHubOAuthHandler
from .google import GoogleOAuthHandler
from .notion import NotionOAuthHandler
from .twitter import TwitterOAuthHandler

if TYPE_CHECKING:
    from ..providers import ProviderName
@@ -16,7 +15,6 @@ HANDLERS_BY_NAME: dict["ProviderName", type["BaseOAuthHandler"]] = {
        GitHubOAuthHandler,
        GoogleOAuthHandler,
        NotionOAuthHandler,
        TwitterOAuthHandler,
    ]
}
# --8<-- [end:HANDLERS_BY_NAMEExample]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar
|
||||
|
||||
from backend.data.model import OAuth2Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -23,9 +23,7 @@ class BaseOAuthHandler(ABC):
|
||||
|
||||
@abstractmethod
|
||||
# --8<-- [start:BaseOAuthHandler3]
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
# --8<-- [end:BaseOAuthHandler3]
|
||||
"""Constructs a login URL that the user can be redirected to"""
|
||||
...
|
||||
@@ -33,7 +31,7 @@ class BaseOAuthHandler(ABC):
|
||||
@abstractmethod
|
||||
# --8<-- [start:BaseOAuthHandler4]
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
self, code: str, scopes: list[str]
|
||||
) -> OAuth2Credentials:
|
||||
# --8<-- [end:BaseOAuthHandler4]
|
||||
"""Exchanges the acquired authorization code from login for a set of tokens"""
|
||||
|
||||
@@ -34,9 +34,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
|
||||
self.token_url = "https://github.com/login/oauth/access_token"
|
||||
self.revoke_url = "https://api.github.com/applications/{client_id}/token"
|
||||
|
||||
def get_login_url(
|
||||
self, scopes: list[str], state: str, code_challenge: Optional[str]
|
||||
) -> str:
|
||||
def get_login_url(self, scopes: list[str], state: str) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@@ -46,7 +44,7 @@ class GitHubOAuthHandler(BaseOAuthHandler):
|
||||
return f"{self.auth_base_url}?{urlencode(params)}"
|
||||
|
||||
def exchange_code_for_tokens(
|
||||
self, code: str, scopes: list[str], code_verifier: Optional[str]
|
||||
self, code: str, scopes: list[str]
|
||||
) -> OAuth2Credentials:
|
||||
return self._request_tokens({"code": code, "redirect_uri": self.redirect_uri})
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
import logging
from typing import Optional

from google.auth.external_account_authorized_user import (
    Credentials as ExternalAccountCredentials,
@@ -39,9 +38,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        self.token_uri = "https://oauth2.googleapis.com/token"
        self.revoke_uri = "https://oauth2.googleapis.com/revoke"

    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]
    ) -> str:
    def get_login_url(self, scopes: list[str], state: str) -> str:
        all_scopes = list(set(scopes + self.DEFAULT_SCOPES))
        logger.debug(f"Setting up OAuth flow with scopes: {all_scopes}")
        flow = self._setup_oauth_flow(all_scopes)
@@ -55,7 +52,7 @@ class GoogleOAuthHandler(BaseOAuthHandler):
        return authorization_url

    def exchange_code_for_tokens(
        self, code: str, scopes: list[str], code_verifier: Optional[str]
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        logger.debug(f"Exchanging code for tokens with scopes: {scopes}")


@@ -1,5 +1,4 @@
from base64 import b64encode
from typing import Optional
from urllib.parse import urlencode

from backend.data.model import OAuth2Credentials
@@ -27,9 +26,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
        self.auth_base_url = "https://api.notion.com/v1/oauth/authorize"
        self.token_url = "https://api.notion.com/v1/oauth/token"

    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]
    ) -> str:
    def get_login_url(self, scopes: list[str], state: str) -> str:
        params = {
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
@@ -40,7 +37,7 @@ class NotionOAuthHandler(BaseOAuthHandler):
        return f"{self.auth_base_url}?{urlencode(params)}"

    def exchange_code_for_tokens(
        self, code: str, scopes: list[str], code_verifier: Optional[str]
        self, code: str, scopes: list[str]
    ) -> OAuth2Credentials:
        request_body = {
            "grant_type": "authorization_code",

@@ -1,171 +0,0 @@
import time
import urllib.parse
from typing import ClassVar, Optional

import requests

from backend.data.model import OAuth2Credentials, ProviderName
from backend.integrations.oauth.base import BaseOAuthHandler


class TwitterOAuthHandler(BaseOAuthHandler):
    PROVIDER_NAME = ProviderName.TWITTER
    DEFAULT_SCOPES: ClassVar[list[str]] = [
        "tweet.read",
        "tweet.write",
        "tweet.moderate.write",
        "users.read",
        "follows.read",
        "follows.write",
        "offline.access",
        "space.read",
        "mute.read",
        "mute.write",
        "like.read",
        "like.write",
        "list.read",
        "list.write",
        "block.read",
        "block.write",
        "bookmark.read",
        "bookmark.write",
    ]

    AUTHORIZE_URL = "https://twitter.com/i/oauth2/authorize"
    TOKEN_URL = "https://api.x.com/2/oauth2/token"
    USERNAME_URL = "https://api.x.com/2/users/me"
    REVOKE_URL = "https://api.x.com/2/oauth2/revoke"

    def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
        self.client_id = client_id
        self.client_secret = client_secret
        self.redirect_uri = redirect_uri

    def get_login_url(
        self, scopes: list[str], state: str, code_challenge: Optional[str]
    ) -> str:
        """Generate Twitter OAuth 2.0 authorization URL"""
        # scopes = self.handle_default_scopes(scopes)

        if code_challenge is None:
            raise ValueError("code_challenge is required for Twitter OAuth")

        params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": " ".join(self.DEFAULT_SCOPES),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
        }

        return f"{self.AUTHORIZE_URL}?{urllib.parse.urlencode(params)}"

    def exchange_code_for_tokens(
        self, code: str, scopes: list[str], code_verifier: Optional[str]
    ) -> OAuth2Credentials:
        """Exchange authorization code for access tokens"""

        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        data = {
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": self.redirect_uri,
            "code_verifier": code_verifier,
        }

        auth = (self.client_id, self.client_secret)

        response = requests.post(self.TOKEN_URL, headers=headers, data=data, auth=auth)
        response.raise_for_status()

        tokens = response.json()

        username = self._get_username(tokens["access_token"])

        return OAuth2Credentials(
            provider=self.PROVIDER_NAME,
            title=None,
            username=username,
            access_token=tokens["access_token"],
            refresh_token=tokens.get("refresh_token"),
            access_token_expires_at=int(time.time()) + tokens["expires_in"],
            refresh_token_expires_at=None,
            scopes=scopes,
        )

    def _get_username(self, access_token: str) -> str:
        """Get the username from the access token"""
        headers = {"Authorization": f"Bearer {access_token}"}

        params = {"user.fields": "username"}

        response = requests.get(
            f"{self.USERNAME_URL}?{urllib.parse.urlencode(params)}", headers=headers
        )
        response.raise_for_status()

        return response.json()["data"]["username"]

    def _refresh_tokens(self, credentials: OAuth2Credentials) -> OAuth2Credentials:
        """Refresh access tokens using refresh token"""
        if not credentials.refresh_token:
            raise ValueError("No refresh token available")

        header = {"Content-Type": "application/x-www-form-urlencoded"}
        data = {
            "grant_type": "refresh_token",
            "refresh_token": credentials.refresh_token.get_secret_value(),
        }

        auth = (self.client_id, self.client_secret)

        response = requests.post(self.TOKEN_URL, headers=header, data=data, auth=auth)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            print("HTTP Error:", e)
            print("Response Content:", response.text)
            raise

        tokens = response.json()

        username = self._get_username(tokens["access_token"])

        return OAuth2Credentials(
            id=credentials.id,
            provider=self.PROVIDER_NAME,
            title=None,
            username=username,
            access_token=tokens["access_token"],
            refresh_token=tokens["refresh_token"],
            access_token_expires_at=int(time.time()) + tokens["expires_in"],
            scopes=credentials.scopes,
            refresh_token_expires_at=None,
        )

    def revoke_tokens(self, credentials: OAuth2Credentials) -> bool:
        """Revoke the access token"""

        header = {"Content-Type": "application/x-www-form-urlencoded"}

        data = {
            "token": credentials.access_token.get_secret_value(),
            "token_type_hint": "access_token",
        }

        auth = (self.client_id, self.client_secret)

        response = requests.post(self.REVOKE_URL, headers=header, data=data, auth=auth)

        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            print("HTTP Error:", e)
            print("Response Content:", response.text)
            raise

        return response.status_code == 200
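Taken together, the deleted handler implements the PKCE authorization-code flow end to end. A hedged usage sketch; the credentials, state, and code values are placeholders, and `generate_pkce_pair` refers to the sketch further up:

```python
from backend.integrations.oauth.twitter import TwitterOAuthHandler

verifier, challenge = generate_pkce_pair()  # see the PKCE sketch above

handler = TwitterOAuthHandler(
    client_id="...",  # placeholder
    client_secret="...",  # placeholder
    redirect_uri="https://app.example.com/auth/callback",
)
login_url = handler.get_login_url(
    scopes=[], state="opaque-state", code_challenge=challenge
)
# The user authorizes; the provider redirects back with ?code=...&state=...
creds = handler.exchange_code_for_tokens(
    code="code-from-redirect", scopes=[], code_verifier=verifier
)
```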
@@ -4,7 +4,6 @@ from enum import Enum
# --8<-- [start:ProviderName]
class ProviderName(str, Enum):
    ANTHROPIC = "anthropic"
    COMPASS = "compass"
    DISCORD = "discord"
    D_ID = "d_id"
    E2B = "e2b"
@@ -19,7 +18,6 @@ class ProviderName(str, Enum):
    JINA = "jina"
    MEDIUM = "medium"
    NOTION = "notion"
    NVIDIA = "nvidia"
    OLLAMA = "ollama"
    OPENAI = "openai"
    OPENWEATHERMAP = "openweathermap"
@@ -28,6 +26,5 @@ class ProviderName(str, Enum):
    REPLICATE = "replicate"
    REVID = "revid"
    SLANT3D = "slant3d"
    TWITTER = "twitter"
    UNREAL_SPEECH = "unreal_speech"
# --8<-- [end:ProviderName]

@@ -1,18 +1,16 @@
from typing import TYPE_CHECKING

from .compass import CompassWebhookManager
from .github import GithubWebhooksManager
from .slant3d import Slant3DWebhooksManager

if TYPE_CHECKING:
    from ..providers import ProviderName
    from ._base import BaseWebhooksManager
    from .base import BaseWebhooksManager

# --8<-- [start:WEBHOOK_MANAGERS_BY_NAME]
WEBHOOK_MANAGERS_BY_NAME: dict["ProviderName", type["BaseWebhooksManager"]] = {
    handler.PROVIDER_NAME: handler
    for handler in [
        CompassWebhookManager,
        GithubWebhooksManager,
        Slant3DWebhooksManager,
    ]

@@ -1,30 +0,0 @@
import logging

from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials, OAuth2Credentials

from ._base import WT, BaseWebhooksManager

logger = logging.getLogger(__name__)


class ManualWebhookManagerBase(BaseWebhooksManager[WT]):
    async def _register_webhook(
        self,
        credentials: Credentials,
        webhook_type: WT,
        resource: str,
        events: list[str],
        ingress_url: str,
        secret: str,
    ) -> tuple[str, dict]:
        print(ingress_url)  # FIXME: pass URL to user in front end

        return "", {}

    async def _deregister_webhook(
        self,
        webhook: integrations.Webhook,
        credentials: OAuth2Credentials | APIKeyCredentials,
    ) -> None:
        pass
@@ -1,7 +1,7 @@
import logging
import secrets
from abc import ABC, abstractmethod
from typing import ClassVar, Generic, Optional, TypeVar
from typing import ClassVar, Generic, TypeVar
from uuid import uuid4

from fastapi import Request
@@ -10,7 +10,6 @@ from strenum import StrEnum
from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks.utils import webhook_ingress_url
from backend.util.exceptions import MissingConfigError
from backend.util.settings import Config

@@ -27,7 +26,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):

    WebhookType: WT

    async def get_suitable_auto_webhook(
    async def get_suitable_webhook(
        self,
        user_id: str,
        credentials: Credentials,
@@ -40,34 +39,16 @@ class BaseWebhooksManager(ABC, Generic[WT]):
                "PLATFORM_BASE_URL must be set to use Webhook functionality"
            )

        if webhook := await integrations.find_webhook_by_credentials_and_props(
        if webhook := await integrations.find_webhook(
            credentials.id, webhook_type, resource, events
        ):
            return webhook
        return await self._create_webhook(
            user_id, webhook_type, events, resource, credentials
        )

    async def get_manual_webhook(
        self,
        user_id: str,
        graph_id: str,
        webhook_type: WT,
        events: list[str],
    ):
        if current_webhook := await integrations.find_webhook_by_graph_and_props(
            graph_id, self.PROVIDER_NAME, webhook_type, events
        ):
            return current_webhook
        return await self._create_webhook(
            user_id,
            webhook_type,
            events,
            register=False,
            user_id, credentials, webhook_type, resource, events
        )

    async def prune_webhook_if_dangling(
        self, webhook_id: str, credentials: Optional[Credentials]
        self, webhook_id: str, credentials: Credentials
    ) -> bool:
        webhook = await integrations.get_webhook(webhook_id)
        if webhook.attached_nodes is None:
@@ -76,8 +57,7 @@ class BaseWebhooksManager(ABC, Generic[WT]):
            # Don't prune webhook if in use
            return False

        if credentials:
            await self._deregister_webhook(webhook, credentials)
        await self._deregister_webhook(webhook, credentials)
        await integrations.delete_webhook(webhook.id)
        return True

@@ -155,36 +135,27 @@ class BaseWebhooksManager(ABC, Generic[WT]):
    async def _create_webhook(
        self,
        user_id: str,
        credentials: Credentials,
        webhook_type: WT,
        resource: str,
        events: list[str],
        resource: str = "",
        credentials: Optional[Credentials] = None,
        register: bool = True,
    ) -> integrations.Webhook:
        if not app_config.platform_base_url:
            raise MissingConfigError(
                "PLATFORM_BASE_URL must be set to use Webhook functionality"
            )

        id = str(uuid4())
        secret = secrets.token_hex(32)
        provider_name = self.PROVIDER_NAME
        ingress_url = webhook_ingress_url(provider_name=provider_name, webhook_id=id)
        if register:
            if not credentials:
                raise TypeError("credentials are required if register = True")
            provider_webhook_id, config = await self._register_webhook(
                credentials, webhook_type, resource, events, ingress_url, secret
            )
        else:
            provider_webhook_id, config = "", {}

        ingress_url = (
            f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
            f"/webhooks/{id}/ingress"
        )
        provider_webhook_id, config = await self._register_webhook(
            credentials, webhook_type, resource, events, ingress_url, secret
        )
        return await integrations.create_webhook(
            integrations.Webhook(
                id=id,
                user_id=user_id,
                provider=provider_name,
                credentials_id=credentials.id if credentials else "",
                credentials_id=credentials.id,
                webhook_type=webhook_type,
                resource=resource,
                events=events,
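The manager hunks above show one side of the comparison splitting webhook acquisition into an auto-registered path and a manual path. A sketch of the two hypothetical call sites; all names are assumed to be in scope as in the hunks:

```python
# Hypothetical call sites for the two acquisition paths; `manager` is a
# BaseWebhooksManager subclass instance and the remaining arguments are
# assumed to be in scope, as in the hunks above.
async def attach_webhooks(manager, user_id, credentials, graph_id,
                          webhook_type, resource, events):
    # Auto-setup path: needs provider credentials, registers remotely.
    auto_webhook = await manager.get_suitable_auto_webhook(
        user_id, credentials, webhook_type, resource, events
    )
    # Manual path: no credentials, the webhook is only created locally.
    manual_webhook = await manager.get_manual_webhook(
        user_id, graph_id, webhook_type, events
    )
    return auto_webhook, manual_webhook
```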
@@ -1,30 +0,0 @@
import logging

from fastapi import Request
from strenum import StrEnum

from backend.data import integrations
from backend.integrations.providers import ProviderName

from ._manual_base import ManualWebhookManagerBase

logger = logging.getLogger(__name__)


class CompassWebhookType(StrEnum):
    TRANSCRIPTION = "transcription"
    TASK = "task"


class CompassWebhookManager(ManualWebhookManagerBase):
    PROVIDER_NAME = ProviderName.COMPASS
    WebhookType = CompassWebhookType

    @classmethod
    async def validate_payload(
        cls, webhook: integrations.Webhook, request: Request
    ) -> tuple[dict, str]:
        payload = await request.json()
        event_type = CompassWebhookType.TRANSCRIPTION  # currently the only type

        return payload, event_type
@@ -10,7 +10,7 @@ from backend.data import integrations
from backend.data.model import Credentials
from backend.integrations.providers import ProviderName

from ._base import BaseWebhooksManager
from .base import BaseWebhooksManager

logger = logging.getLogger(__name__)


@@ -1,15 +1,16 @@
import logging
from typing import TYPE_CHECKING, Callable, Optional, cast

from backend.data.block import BlockSchema, BlockWebhookConfig, get_block
from backend.data.block import get_block
from backend.data.graph import set_node_webhook
from backend.data.model import CREDENTIALS_FIELD_NAME
from backend.integrations.webhooks import WEBHOOK_MANAGERS_BY_NAME

if TYPE_CHECKING:
    from backend.data.graph import GraphModel, NodeModel
    from backend.data.model import Credentials

    from ._base import BaseWebhooksManager
    from .base import BaseWebhooksManager

logger = logging.getLogger(__name__)

@@ -29,28 +30,14 @@ async def on_graph_activate(
    # Compare nodes in new_graph_version with previous_graph_version
    updated_nodes = []
    for new_node in graph.nodes:
        block = get_block(new_node.block_id)
        if not block:
            raise ValueError(
                f"Node #{new_node.id} is instance of unknown block #{new_node.block_id}"
            )
        block_input_schema = cast(BlockSchema, block.input_schema)

        node_credentials = None
        if (
            # Webhook-triggered blocks are only allowed to have 1 credentials input
            (
                creds_field_name := next(
                    iter(block_input_schema.get_credentials_fields()), None
        if creds_meta := new_node.input_default.get(CREDENTIALS_FIELD_NAME):
            node_credentials = get_credentials(creds_meta["id"])
            if not node_credentials:
                raise ValueError(
                    f"Node #{new_node.id} updated with non-existent "
                    f"credentials #{node_credentials}"
                )
            )
            and (creds_meta := new_node.input_default.get(creds_field_name))
            and not (node_credentials := get_credentials(creds_meta["id"]))
        ):
            raise ValueError(
                f"Node #{new_node.id} input '{creds_field_name}' updated with "
                f"non-existent credentials #{creds_meta['id']}"
            )

        updated_node = await on_node_activate(
            graph.user_id, new_node, credentials=node_credentials
@@ -75,28 +62,14 @@ async def on_graph_deactivate(
    """
    updated_nodes = []
    for node in graph.nodes:
        block = get_block(node.block_id)
        if not block:
            raise ValueError(
                f"Node #{node.id} is instance of unknown block #{node.block_id}"
            )
        block_input_schema = cast(BlockSchema, block.input_schema)

        node_credentials = None
        if (
            # Webhook-triggered blocks are only allowed to have 1 credentials input
            (
                creds_field_name := next(
                    iter(block_input_schema.get_credentials_fields()), None
        if creds_meta := node.input_default.get(CREDENTIALS_FIELD_NAME):
            node_credentials = get_credentials(creds_meta["id"])
            if not node_credentials:
                logger.error(
                    f"Node #{node.id} referenced non-existent "
                    f"credentials #{creds_meta['id']}"
                )
            )
            and (creds_meta := node.input_default.get(creds_field_name))
            and not (node_credentials := get_credentials(creds_meta["id"]))
        ):
            logger.error(
                f"Node #{node.id} input '{creds_field_name}' referenced non-existent "
                f"credentials #{creds_meta['id']}"
            )

        updated_node = await on_node_deactivate(node, credentials=node_credentials)
        updated_nodes.append(updated_node)
@@ -135,82 +108,50 @@ async def on_node_activate(

    webhooks_manager = WEBHOOK_MANAGERS_BY_NAME[provider]()

    if auto_setup_webhook := isinstance(block.webhook_config, BlockWebhookConfig):
        try:
            resource = block.webhook_config.resource_format.format(**node.input_default)
        except KeyError:
            resource = None
        logger.debug(
            f"Constructed resource string {resource} from input {node.input_default}"
        )
    else:
        resource = ""  # not relevant for manual webhooks

    block_input_schema = cast(BlockSchema, block.input_schema)
    credentials_field_name = next(iter(block_input_schema.get_credentials_fields()), "")
    credentials_meta = (
        node.input_default.get(credentials_field_name)
        if credentials_field_name
        else None
    try:
        resource = block.webhook_config.resource_format.format(**node.input_default)
    except KeyError:
        resource = None
    logger.debug(
        f"Constructed resource string {resource} from input {node.input_default}"
    )

    event_filter_input_name = block.webhook_config.event_filter_input
    has_everything_for_webhook = (
        resource is not None
        and (credentials_meta or not credentials_field_name)
        and (
            not event_filter_input_name
            or (
                event_filter_input_name in node.input_default
                and any(
                    is_on
                    for is_on in node.input_default[event_filter_input_name].values()
                )
            )
        )
        and CREDENTIALS_FIELD_NAME in node.input_default
        and event_filter_input_name in node.input_default
        and any(is_on for is_on in node.input_default[event_filter_input_name].values())
    )

    if has_everything_for_webhook and resource is not None:
    if has_everything_for_webhook and resource:
        logger.debug(f"Node #{node} has everything for a webhook!")
        if credentials_meta and not credentials:
        if not credentials:
            credentials_meta = node.input_default[CREDENTIALS_FIELD_NAME]
            raise ValueError(
                f"Cannot set up webhook for node #{node.id}: "
                f"credentials #{credentials_meta['id']} not available"
            )

        if event_filter_input_name:
            # Shape of the event filter is enforced in Block.__init__
            event_filter = cast(dict, node.input_default[event_filter_input_name])
            events = [
                block.webhook_config.event_format.format(event=event)
                for event, enabled in event_filter.items()
                if enabled is True
            ]
            logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")
        else:
            events = []
        # Shape of the event filter is enforced in Block.__init__
        event_filter = cast(dict, node.input_default[event_filter_input_name])
        events = [
            block.webhook_config.event_format.format(event=event)
            for event, enabled in event_filter.items()
            if enabled is True
        ]
        logger.debug(f"Webhook events to subscribe to: {', '.join(events)}")

        # Find/make and attach a suitable webhook to the node
        if auto_setup_webhook:
            assert credentials is not None
            new_webhook = await webhooks_manager.get_suitable_auto_webhook(
                user_id,
                credentials,
                block.webhook_config.webhook_type,
                resource,
                events,
            )
        else:
            # Manual webhook -> no credentials -> don't register but do create
            new_webhook = await webhooks_manager.get_manual_webhook(
                user_id,
                node.graph_id,
                block.webhook_config.webhook_type,
                events,
            )
        new_webhook = await webhooks_manager.get_suitable_webhook(
            user_id,
            credentials,
            block.webhook_config.webhook_type,
            resource,
            events,
        )
        logger.debug(f"Acquired webhook: {new_webhook}")
        return await set_node_webhook(node.id, new_webhook.id)
    else:
        logger.debug(f"Node #{node.id} does not have everything for a webhook")

    return node

@@ -253,16 +194,12 @@ async def on_node_deactivate(
        updated_node = await set_node_webhook(node.id, None)

        # Prune and deregister the webhook if it is no longer used anywhere
        logger.debug("Pruning and deregistering webhook if dangling")
        webhook = node.webhook
        logger.debug(
            f"Pruning{' and deregistering' if credentials else ''} "
            f"webhook #{webhook.id}"
        )
        await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
        if (
            cast(BlockSchema, block.input_schema).get_credentials_fields()
            and not credentials
        ):
        if credentials:
            logger.debug(f"Pruning webhook #{webhook.id} with credentials")
            await webhooks_manager.prune_webhook_if_dangling(webhook.id, credentials)
        else:
            logger.warning(
                f"Cannot deregister webhook #{webhook.id}: credentials "
                f"#{webhook.credentials_id} not available "

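Both versions of `on_node_activate` share the event-filter expansion seen above. A self-contained sketch of that mapping with illustrative values:

```python
# Self-contained sketch: a {event_name: bool} filter from the node's input
# becomes a list of provider event strings via the block's event_format.
event_filter = {"pull_request": True, "push": False}  # illustrative input
event_format = "{event}"  # illustrative; e.g. plain GitHub-style event names
events = [
    event_format.format(event=event)
    for event, enabled in event_filter.items()
    if enabled is True
]
assert events == ["pull_request"]
```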
@@ -6,7 +6,7 @@ from fastapi import Request
from backend.data import integrations
from backend.data.model import APIKeyCredentials, Credentials
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks._base import BaseWebhooksManager
from backend.integrations.webhooks.base import BaseWebhooksManager

logger = logging.getLogger(__name__)


@@ -1,12 +0,0 @@
from backend.integrations.providers import ProviderName
from backend.util.settings import Config

app_config = Config()


# TODO: add test to assert this matches the actual API route
def webhook_ingress_url(provider_name: ProviderName, webhook_id: str) -> str:
    return (
        f"{app_config.platform_base_url}/api/integrations/{provider_name.value}"
        f"/webhooks/{webhook_id}/ingress"
    )
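The deleted util centralizes the ingress-URL format that `_create_webhook` inlines on the other side of the comparison. Assuming `PLATFORM_BASE_URL` is set, usage looks like:

```python
# Illustrative usage; assumes PLATFORM_BASE_URL = "https://app.example.com".
from backend.integrations.providers import ProviderName

url = webhook_ingress_url(provider_name=ProviderName.GITHUB, webhook_id="1234")
# -> "https://app.example.com/api/integrations/github/webhooks/1234/ingress"
```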
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field, SecretStr
from backend.data.graph import set_node_webhook
from backend.data.integrations import (
    WebhookEvent,
    get_all_webhooks_by_creds,
    get_all_webhooks,
    get_webhook,
    publish_webhook_event,
    wait_for_webhook_event,
@@ -60,12 +60,11 @@ def login(
    requested_scopes = scopes.split(",") if scopes else []

    # Generate and store a secure random state token along with the scopes
    state_token, code_challenge = creds_manager.store.store_state_token(
    state_token = creds_manager.store.store_state_token(
        user_id, provider, requested_scopes
    )
    login_url = handler.get_login_url(
        requested_scopes, state_token, code_challenge=code_challenge
    )

    login_url = handler.get_login_url(requested_scopes, state_token)

    return LoginResponse(login_url=login_url, state_token=state_token)

@@ -93,21 +92,19 @@ def callback(
    handler = _get_provider_oauth_handler(request, provider)

    # Verify the state token
    valid_state = creds_manager.store.verify_state_token(user_id, state_token, provider)

    if not valid_state:
    if not creds_manager.store.verify_state_token(user_id, state_token, provider):
        logger.warning(f"Invalid or expired state token for user {user_id}")
        raise HTTPException(status_code=400, detail="Invalid or expired state token")

    try:
        scopes = valid_state.scopes
        scopes = creds_manager.store.get_any_valid_scopes_from_state_token(
            user_id, state_token, provider
        )
        logger.debug(f"Retrieved scopes from state token: {scopes}")

        scopes = handler.handle_default_scopes(scopes)

        credentials = handler.exchange_code_for_tokens(
            code, scopes, valid_state.code_verifier
        )

        credentials = handler.exchange_code_for_tokens(code, scopes)
        logger.debug(f"Received credentials with final scopes: {credentials.scopes}")

        # Check if the granted scopes are sufficient for the requested scopes
@@ -366,7 +363,7 @@ async def remove_all_webhooks_for_credentials(
    Raises:
        NeedConfirmation: If any of the webhooks are still in use and `force` is `False`
    """
    webhooks = await get_all_webhooks_by_creds(credentials.id)
    webhooks = await get_all_webhooks(credentials.id)
    if credentials.provider not in WEBHOOK_MANAGERS_BY_NAME:
        if webhooks:
            logger.error(

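In the `login`/`callback` hunks, the two versions differ on whether the state-token store also carries a PKCE pair. A condensed sketch of the PKCE variant's round trip, using only names that appear in the hunks above; this is a fragment, not a drop-in implementation:

```python
# login: persist state alongside a PKCE pair, then build the redirect URL
state_token, code_challenge = creds_manager.store.store_state_token(
    user_id, provider, requested_scopes
)
login_url = handler.get_login_url(
    requested_scopes, state_token, code_challenge=code_challenge
)

# callback: verify the state token, then redeem the code with the verifier
valid_state = creds_manager.store.verify_state_token(user_id, state_token, provider)
if valid_state:
    credentials = handler.exchange_code_for_tokens(
        code, valid_state.scopes, valid_state.code_verifier
    )
```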
@@ -16,7 +16,6 @@ import backend.data.db
import backend.data.graph
import backend.data.user
import backend.server.routers.v1
import backend.server.v2.library.routes
import backend.server.v2.store.routes
import backend.util.service
import backend.util.settings
@@ -90,9 +89,6 @@ app.include_router(backend.server.routers.v1.v1_router, tags=["v1"], prefix="/ap
app.include_router(
    backend.server.v2.store.routes.router, tags=["v2"], prefix="/api/store"
)
app.include_router(
    backend.server.v2.library.routes.router, tags=["v2"], prefix="/api/library"
)


@app.get(path="/health", tags=["health"], dependencies=[])

@@ -149,7 +149,7 @@ class DeleteGraphResponse(TypedDict):
@v1_router.get(path="/graphs", tags=["graphs"], dependencies=[Depends(auth_middleware)])
async def get_graphs(
    user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
) -> Sequence[graph_db.Graph]:
    return await graph_db.get_graphs(filter_by="active", user_id=user_id)


@@ -166,9 +166,9 @@ async def get_graph(
    user_id: Annotated[str, Depends(get_user_id)],
    version: int | None = None,
    hide_credentials: bool = False,
) -> graph_db.GraphModel:
) -> graph_db.Graph:
    graph = await graph_db.get_graph(
        graph_id, version, user_id=user_id, for_export=hide_credentials
        graph_id, version, user_id=user_id, hide_credentials=hide_credentials
    )
    if not graph:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -187,7 +187,7 @@ async def get_graph(
)
async def get_graph_all_versions(
    graph_id: str, user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
) -> Sequence[graph_db.Graph]:
    graphs = await graph_db.get_graph_all_versions(graph_id, user_id=user_id)
    if not graphs:
        raise HTTPException(status_code=404, detail=f"Graph #{graph_id} not found.")
@@ -199,7 +199,7 @@ async def get_graph_all_versions(
)
async def create_new_graph(
    create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
) -> graph_db.Graph:
    return await do_create_graph(create_graph, is_template=False, user_id=user_id)


@@ -209,7 +209,7 @@ async def do_create_graph(
    # user_id doesn't have to be annotated like on other endpoints,
    # because create_graph isn't used directly as an endpoint
    user_id: str,
) -> graph_db.GraphModel:
) -> graph_db.Graph:
    if create_graph.graph:
        graph = graph_db.make_graph_model(create_graph.graph, user_id)
    elif create_graph.template_id:
@@ -270,7 +270,7 @@ async def update_graph(
    graph_id: str,
    graph: graph_db.Graph,
    user_id: Annotated[str, Depends(get_user_id)],
) -> graph_db.GraphModel:
) -> graph_db.Graph:
    # Sanity check
    if graph.id and graph.id != graph_id:
        raise HTTPException(400, detail="Graph ID does not match ID in URI")
@@ -440,7 +440,7 @@ async def get_graph_run_node_execution_results(
)
async def get_templates(
    user_id: Annotated[str, Depends(get_user_id)]
) -> Sequence[graph_db.GraphModel]:
) -> Sequence[graph_db.Graph]:
    return await graph_db.get_graphs(filter_by="template", user_id=user_id)


@@ -449,9 +449,7 @@ async def get_templates(
    tags=["templates", "graphs"],
    dependencies=[Depends(auth_middleware)],
)
async def get_template(
    graph_id: str, version: int | None = None
) -> graph_db.GraphModel:
async def get_template(graph_id: str, version: int | None = None) -> graph_db.Graph:
    graph = await graph_db.get_graph(graph_id, version, template=True)
    if not graph:
        raise HTTPException(status_code=404, detail=f"Template #{graph_id} not found.")
@@ -465,7 +463,7 @@ async def get_template(
)
async def create_new_template(
    create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)]
) -> graph_db.GraphModel:
) -> graph_db.Graph:
    return await do_create_graph(create_graph, is_template=True, user_id=user_id)


@@ -541,7 +539,7 @@ def get_execution_schedules(

@v1_router.post(
    "/api-keys",
    response_model=CreateAPIKeyResponse,
    response_model=list[CreateAPIKeyResponse] | dict[str, str],
    tags=["api-keys"],
    dependencies=[Depends(auth_middleware)],
)
@@ -583,7 +581,7 @@ async def get_api_keys(

@v1_router.get(
    "/api-keys/{key_id}",
    response_model=APIKeyWithoutHash,
    response_model=list[APIKeyWithoutHash] | dict[str, str],
    tags=["api-keys"],
    dependencies=[Depends(auth_middleware)],
)
@@ -604,7 +602,7 @@ async def get_api_key(

@v1_router.delete(
    "/api-keys/{key_id}",
    response_model=APIKeyWithoutHash,
    response_model=list[APIKeyWithoutHash] | dict[str, str],
    tags=["api-keys"],
    dependencies=[Depends(auth_middleware)],
)
@@ -626,7 +624,7 @@ async def delete_api_key(

@v1_router.post(
    "/api-keys/{key_id}/suspend",
    response_model=APIKeyWithoutHash,
    response_model=list[APIKeyWithoutHash] | dict[str, str],
    tags=["api-keys"],
    dependencies=[Depends(auth_middleware)],
)
@@ -648,7 +646,7 @@ async def suspend_key(

@v1_router.put(
    "/api-keys/{key_id}/permissions",
    response_model=APIKeyWithoutHash,
    response_model=list[APIKeyWithoutHash] | dict[str, str],
    tags=["api-keys"],
    dependencies=[Depends(auth_middleware)],
)

@@ -1,165 +0,0 @@
import logging
from typing import List

import prisma.errors
import prisma.models
import prisma.types

import backend.data.graph
import backend.data.includes
import backend.server.v2.library.model
import backend.server.v2.store.exceptions

logger = logging.getLogger(__name__)


async def get_library_agents(
    user_id: str,
) -> List[backend.server.v2.library.model.LibraryAgent]:
    """
    Returns all agents (AgentGraph) that belong to the user and all agents in their library (UserAgent table)
    """
    logger.debug(f"Getting library agents for user {user_id}")

    try:
        # Get agents created by user with nodes and links
        user_created = await prisma.models.AgentGraph.prisma().find_many(
            where=prisma.types.AgentGraphWhereInput(userId=user_id, isActive=True),
            include=backend.data.includes.AGENT_GRAPH_INCLUDE,
        )

        # Get agents in user's library with nodes and links
        library_agents = await prisma.models.UserAgent.prisma().find_many(
            where=prisma.types.UserAgentWhereInput(
                userId=user_id, isDeleted=False, isArchived=False
            ),
            include={
                "Agent": {
                    "include": {
                        "AgentNodes": {
                            "include": {
                                "Input": True,
                                "Output": True,
                                "Webhook": True,
                                "AgentBlock": True,
                            }
                        }
                    }
                }
            },
        )

        # Convert to Graph models first
        graphs = []

        # Add user created agents
        for agent in user_created:
            try:
                graphs.append(backend.data.graph.GraphModel.from_db(agent))
            except Exception as e:
                logger.error(f"Error processing user created agent {agent.id}: {e}")
                continue

        # Add library agents
        for agent in library_agents:
            if agent.Agent:
                try:
                    graphs.append(backend.data.graph.GraphModel.from_db(agent.Agent))
                except Exception as e:
                    logger.error(f"Error processing library agent {agent.agentId}: {e}")
                    continue

        # Convert Graph models to LibraryAgent models
        result = []
        for graph in graphs:
            result.append(
                backend.server.v2.library.model.LibraryAgent(
                    id=graph.id,
                    version=graph.version,
                    is_active=graph.is_active,
                    name=graph.name,
                    description=graph.description,
                    isCreatedByUser=any(a.id == graph.id for a in user_created),
                    input_schema=graph.input_schema,
                    output_schema=graph.output_schema,
                )
            )

        logger.debug(f"Found {len(result)} library agents")
        return result

    except prisma.errors.PrismaError as e:
        logger.error(f"Database error getting library agents: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch library agents"
        ) from e


async def add_agent_to_library(store_listing_version_id: str, user_id: str) -> None:
    """
    Finds the agent from the store listing version and adds it to the user's library (UserAgent table)
    if they don't already have it
    """
    logger.debug(
        f"Adding agent from store listing version {store_listing_version_id} to library for user {user_id}"
    )

    try:
        # Get store listing version to find agent
        store_listing_version = (
            await prisma.models.StoreListingVersion.prisma().find_unique(
                where={"id": store_listing_version_id}, include={"Agent": True}
            )
        )

        if not store_listing_version or not store_listing_version.Agent:
            logger.warning(
                f"Store listing version not found: {store_listing_version_id}"
            )
            raise backend.server.v2.store.exceptions.AgentNotFoundError(
                f"Store listing version {store_listing_version_id} not found"
            )

        agent = store_listing_version.Agent

        if agent.userId == user_id:
            logger.warning(
                f"User {user_id} cannot add their own agent to their library"
            )
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Cannot add own agent to library"
            )

        # Check if user already has this agent
        existing_user_agent = await prisma.models.UserAgent.prisma().find_first(
            where={
                "userId": user_id,
                "agentId": agent.id,
                "agentVersion": agent.version,
            }
        )

        if existing_user_agent:
            logger.debug(
                f"User {user_id} already has agent {agent.id} in their library"
            )
            return

        # Create UserAgent entry
        await prisma.models.UserAgent.prisma().create(
            data=prisma.types.UserAgentCreateInput(
                userId=user_id,
                agentId=agent.id,
                agentVersion=agent.version,
                isCreatedByUser=False,
            )
        )
        logger.debug(f"Added agent {agent.id} to library for user {user_id}")

    except backend.server.v2.store.exceptions.AgentNotFoundError:
        raise
    except prisma.errors.PrismaError as e:
        logger.error(f"Database error adding agent to library: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to add agent to library"
        ) from e
@@ -1,197 +0,0 @@
from datetime import datetime

import prisma.errors
import prisma.models
import pytest
from prisma import Prisma

import backend.data.includes
import backend.server.v2.library.db as db
import backend.server.v2.store.exceptions


@pytest.fixture(autouse=True)
async def setup_prisma():
    # Don't register client if already registered
    try:
        Prisma()
    except prisma.errors.ClientAlreadyRegisteredError:
        pass
    yield


@pytest.mark.asyncio
async def test_get_library_agents(mocker):
    # Mock data
    mock_user_created = [
        prisma.models.AgentGraph(
            id="agent1",
            version=1,
            name="Test Agent 1",
            description="Test Description 1",
            userId="test-user",
            isActive=True,
            createdAt=datetime.now(),
            isTemplate=False,
        )
    ]

    mock_library_agents = [
        prisma.models.UserAgent(
            id="ua1",
            userId="test-user",
            agentId="agent2",
            agentVersion=1,
            isCreatedByUser=False,
            isDeleted=False,
            isArchived=False,
            createdAt=datetime.now(),
            updatedAt=datetime.now(),
            isFavorite=False,
            Agent=prisma.models.AgentGraph(
                id="agent2",
                version=1,
                name="Test Agent 2",
                description="Test Description 2",
                userId="other-user",
                isActive=True,
                createdAt=datetime.now(),
                isTemplate=False,
            ),
        )
    ]

    # Mock prisma calls
    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
    mock_agent_graph.return_value.find_many = mocker.AsyncMock(
        return_value=mock_user_created
    )

    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
    mock_user_agent.return_value.find_many = mocker.AsyncMock(
        return_value=mock_library_agents
    )

    # Call function
    result = await db.get_library_agents("test-user")

    # Verify results
    assert len(result) == 2
    assert result[0].id == "agent1"
    assert result[0].name == "Test Agent 1"
    assert result[0].description == "Test Description 1"
    assert result[0].isCreatedByUser is True
    assert result[1].id == "agent2"
    assert result[1].name == "Test Agent 2"
    assert result[1].description == "Test Description 2"
    assert result[1].isCreatedByUser is False

    # Verify mocks called correctly
    mock_agent_graph.return_value.find_many.assert_called_once_with(
        where=prisma.types.AgentGraphWhereInput(userId="test-user", isActive=True),
        include=backend.data.includes.AGENT_GRAPH_INCLUDE,
    )
    mock_user_agent.return_value.find_many.assert_called_once_with(
        where=prisma.types.UserAgentWhereInput(
            userId="test-user", isDeleted=False, isArchived=False
        ),
        include={
            "Agent": {
                "include": {
                    "AgentNodes": {
                        "include": {
                            "Input": True,
                            "Output": True,
                            "Webhook": True,
                            "AgentBlock": True,
                        }
                    }
                }
            }
        },
    )


@pytest.mark.asyncio
async def test_add_agent_to_library(mocker):
    # Mock data
    mock_store_listing = prisma.models.StoreListingVersion(
        id="version123",
        version=1,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
        agentId="agent1",
        agentVersion=1,
        slug="test-agent",
        name="Test Agent",
        subHeading="Test Agent Subheading",
        imageUrls=["https://example.com/image.jpg"],
        description="Test Description",
        categories=["test"],
        isFeatured=False,
        isDeleted=False,
        isAvailable=True,
        isApproved=True,
        Agent=prisma.models.AgentGraph(
            id="agent1",
            version=1,
            name="Test Agent",
            description="Test Description",
            userId="creator",
            isActive=True,
            createdAt=datetime.now(),
            isTemplate=False,
        ),
    )

    # Mock prisma calls
    mock_store_listing_version = mocker.patch(
        "prisma.models.StoreListingVersion.prisma"
    )
    mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
        return_value=mock_store_listing
    )

    mock_user_agent = mocker.patch("prisma.models.UserAgent.prisma")
    mock_user_agent.return_value.find_first = mocker.AsyncMock(return_value=None)
    mock_user_agent.return_value.create = mocker.AsyncMock()

    # Call function
    await db.add_agent_to_library("version123", "test-user")

    # Verify mocks called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
    )
    mock_user_agent.return_value.find_first.assert_called_once_with(
        where={
            "userId": "test-user",
            "agentId": "agent1",
            "agentVersion": 1,
        }
    )
    mock_user_agent.return_value.create.assert_called_once_with(
        data=prisma.types.UserAgentCreateInput(
            userId="test-user", agentId="agent1", agentVersion=1, isCreatedByUser=False
        )
    )


@pytest.mark.asyncio
async def test_add_agent_to_library_not_found(mocker):
    # Mock prisma calls
    mock_store_listing_version = mocker.patch(
        "prisma.models.StoreListingVersion.prisma"
    )
    mock_store_listing_version.return_value.find_unique = mocker.AsyncMock(
        return_value=None
    )

    # Call function and verify exception
    with pytest.raises(backend.server.v2.store.exceptions.AgentNotFoundError):
        await db.add_agent_to_library("version123", "test-user")

    # Verify mock called correctly
    mock_store_listing_version.return_value.find_unique.assert_called_once_with(
        where={"id": "version123"}, include={"Agent": True}
    )
@@ -1,16 +0,0 @@
import typing

import pydantic


class LibraryAgent(pydantic.BaseModel):
    id: str  # Changed from agent_id to match GraphMeta
    version: int  # Changed from agent_version to match GraphMeta
    is_active: bool  # Added to match GraphMeta
    name: str
    description: str

    isCreatedByUser: bool
    # Made input_schema and output_schema match GraphMeta's type
    input_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
    output_schema: dict[str, typing.Any]  # Should be BlockIOObjectSubSchema in frontend
@@ -1,43 +0,0 @@
import backend.server.v2.library.model


def test_library_agent():
    agent = backend.server.v2.library.model.LibraryAgent(
        id="test-agent-123",
        version=1,
        is_active=True,
        name="Test Agent",
        description="Test description",
        isCreatedByUser=False,
        input_schema={"type": "object", "properties": {}},
        output_schema={"type": "object", "properties": {}},
    )
    assert agent.id == "test-agent-123"
    assert agent.version == 1
    assert agent.is_active is True
    assert agent.name == "Test Agent"
    assert agent.description == "Test description"
    assert agent.isCreatedByUser is False
    assert agent.input_schema == {"type": "object", "properties": {}}
    assert agent.output_schema == {"type": "object", "properties": {}}


def test_library_agent_with_user_created():
    agent = backend.server.v2.library.model.LibraryAgent(
        id="user-agent-456",
        version=2,
        is_active=True,
        name="User Created Agent",
        description="An agent created by the user",
        isCreatedByUser=True,
        input_schema={"type": "object", "properties": {}},
        output_schema={"type": "object", "properties": {}},
    )
    assert agent.id == "user-agent-456"
    assert agent.version == 2
    assert agent.is_active is True
    assert agent.name == "User Created Agent"
    assert agent.description == "An agent created by the user"
    assert agent.isCreatedByUser is True
    assert agent.input_schema == {"type": "object", "properties": {}}
    assert agent.output_schema == {"type": "object", "properties": {}}
@@ -1,123 +0,0 @@
import logging
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import prisma

import backend.data.graph
import backend.integrations.creds_manager
import backend.integrations.webhooks.graph_lifecycle_hooks
import backend.server.v2.library.db
import backend.server.v2.library.model

logger = logging.getLogger(__name__)

router = fastapi.APIRouter()
integration_creds_manager = (
    backend.integrations.creds_manager.IntegrationCredentialsManager()
)


@router.get(
    "/agents",
    tags=["library", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_library_agents(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
) -> typing.Sequence[backend.server.v2.library.model.LibraryAgent]:
    """
    Get all agents in the user's library, including both created and saved agents.
    """
    try:
        agents = await backend.server.v2.library.db.get_library_agents(user_id)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting library agents")
        raise fastapi.HTTPException(
            status_code=500, detail="Failed to get library agents"
        )


@router.post(
    "/agents/{store_listing_version_id}",
    tags=["library", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    status_code=201,
)
async def add_agent_to_library(
    store_listing_version_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> fastapi.Response:
    """
    Add an agent from the store to the user's library.

    Args:
        store_listing_version_id (str): ID of the store listing version to add
        user_id (str): ID of the authenticated user

    Returns:
        fastapi.Response: 201 status code on success

    Raises:
        HTTPException: If there is an error adding the agent to the library
    """
    try:
        # Get the graph from the store listing
        store_listing_version = (
            await prisma.models.StoreListingVersion.prisma().find_unique(
                where={"id": store_listing_version_id}, include={"Agent": True}
            )
        )

        if not store_listing_version or not store_listing_version.Agent:
            raise fastapi.HTTPException(
                status_code=404,
                detail=f"Store listing version {store_listing_version_id} not found",
            )

        agent = store_listing_version.Agent

        if agent.userId == user_id:
            raise fastapi.HTTPException(
                status_code=400, detail="Cannot add own agent to library"
            )

        # Create a new graph from the template
        graph = await backend.data.graph.get_graph(
            agent.id, agent.version, template=True, user_id=user_id
        )

        if not graph:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Agent {agent.id} not found"
            )

        # Create a deep copy with new IDs
        graph.version = 1
        graph.is_template = False
        graph.is_active = True
        graph.reassign_ids(user_id=user_id, reassign_graph_id=True)

        # Save the new graph
        graph = await backend.data.graph.create_graph(graph, user_id=user_id)
        graph = (
            await backend.integrations.webhooks.graph_lifecycle_hooks.on_graph_activate(
                graph,
                get_credentials=lambda id: integration_creds_manager.get(user_id, id),
            )
        )

        return fastapi.Response(status_code=201)

    except Exception:
        logger.exception("Exception occurred whilst adding agent to library")
        raise fastapi.HTTPException(
            status_code=500, detail="Failed to add agent to library"
        )
@@ -1,106 +0,0 @@
import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.testclient
import pytest
import pytest_mock

import backend.server.v2.library.db
import backend.server.v2.library.model
import backend.server.v2.library.routes

app = fastapi.FastAPI()
app.include_router(backend.server.v2.library.routes.router)

client = fastapi.testclient.TestClient(app)


def override_auth_middleware():
    """Override auth middleware for testing"""
    return {"sub": "test-user-id"}


def override_get_user_id():
    """Override get_user_id for testing"""
    return "test-user-id"


app.dependency_overrides[autogpt_libs.auth.middleware.auth_middleware] = (
    override_auth_middleware
)
app.dependency_overrides[autogpt_libs.auth.depends.get_user_id] = override_get_user_id


def test_get_library_agents_success(mocker: pytest_mock.MockFixture):
    mocked_value = [
        backend.server.v2.library.model.LibraryAgent(
            id="test-agent-1",
            version=1,
            is_active=True,
            name="Test Agent 1",
            description="Test Description 1",
            isCreatedByUser=True,
            input_schema={"type": "object", "properties": {}},
            output_schema={"type": "object", "properties": {}},
        ),
        backend.server.v2.library.model.LibraryAgent(
            id="test-agent-2",
            version=1,
            is_active=True,
            name="Test Agent 2",
            description="Test Description 2",
            isCreatedByUser=False,
            input_schema={"type": "object", "properties": {}},
            output_schema={"type": "object", "properties": {}},
        ),
    ]
    mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
    mock_db_call.return_value = mocked_value

    response = client.get("/agents")
    assert response.status_code == 200

    data = [
        backend.server.v2.library.model.LibraryAgent.model_validate(agent)
        for agent in response.json()
    ]
    assert len(data) == 2
    assert data[0].id == "test-agent-1"
    assert data[0].isCreatedByUser is True
    assert data[1].id == "test-agent-2"
    assert data[1].isCreatedByUser is False
    mock_db_call.assert_called_once_with("test-user-id")


def test_get_library_agents_error(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.get_library_agents")
    mock_db_call.side_effect = Exception("Test error")

    response = client.get("/agents")
    assert response.status_code == 500
    mock_db_call.assert_called_once_with("test-user-id")


@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_success(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
    mock_db_call.return_value = None

    response = client.post("/agents/test-version-id")
    assert response.status_code == 201
    mock_db_call.assert_called_once_with(
        store_listing_version_id="test-version-id", user_id="test-user-id"
    )


@pytest.mark.skip(reason="Mocker Not implemented")
def test_add_agent_to_library_error(mocker: pytest_mock.MockFixture):
    mock_db_call = mocker.patch("backend.server.v2.library.db.add_agent_to_library")
    mock_db_call.side_effect = Exception("Test error")

    response = client.post("/agents/test-version-id")
    assert response.status_code == 500
    assert response.json()["detail"] == "Failed to add agent to library"
    mock_db_call.assert_called_once_with(
        store_listing_version_id="test-version-id", user_id="test-user-id"
    )
@@ -1,53 +0,0 @@
# Store Module

This module implements the backend API for the AutoGPT Store, handling agents, creators, profiles, submissions and media uploads.

## Files

### routes.py
Contains the FastAPI route handlers for the store API endpoints:

- Profile endpoints for managing user profiles
- Agent endpoints for browsing and retrieving store agents
- Creator endpoints for browsing and retrieving creator details
- Store submission endpoints for submitting agents to the store
- Media upload endpoints for submission images/videos

### model.py
Contains Pydantic models for request/response validation and serialization:

- Pagination model for paginated responses
- Models for agents, creators, profiles, submissions
- Request/response models for all API endpoints

### db.py
Contains database access functions using Prisma ORM:

- Functions to query and manipulate store data
- Handles database operations for all API endpoints
- Implements business logic and data validation

### media.py
Handles media file uploads to Google Cloud Storage:

- Validates file types and sizes
- Processes image and video uploads
- Stores files in GCS buckets
- Returns public URLs for uploaded media

## Key Features

- Paginated listings of store agents and creators
- Search and filtering of agents and creators
- Agent submission workflow
- Media file upload handling
- Profile management
- Reviews and ratings

## Authentication

Most endpoints require authentication via the AutoGPT auth middleware. Public endpoints are marked with the "public" tag.
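As a rough sketch of what that looks like in practice (mirroring the route declarations in the routes.py diff further down; the endpoint body here is a placeholder):

```python
import typing

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi

router = fastapi.APIRouter()


# Private endpoints attach the auth middleware and resolve the caller's user id;
# "public" endpoints simply omit the dependency.
@router.get(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def get_submissions(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
):
    ...  # placeholder body
```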
## Error Handling

All database and storage operations include proper error handling and logging. Errors are mapped to appropriate HTTP status codes.
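A minimal sketch of that mapping, using the exception and handler names that appear in the diffs below (the handler name is hypothetical and the db call's arguments are elided):

```python
import logging

import fastapi.responses

import backend.server.v2.store.db
import backend.server.v2.store.exceptions

logger = logging.getLogger(__name__)


async def get_creators_endpoint():
    try:
        # db.py raises DatabaseError on invalid input or query failure...
        return await backend.server.v2.store.db.get_store_creators()  # args elided
    except Exception:
        # ...and the route handler logs it and maps it to an HTTP 500.
        logger.exception("Exception occurred whilst getting store creators")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the store creators"},
        )
```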
@@ -1,18 +1,13 @@
import logging
import random
from datetime import datetime
from typing import Optional

import fastapi
import random
import prisma.enums
import prisma.errors
import prisma.models
import prisma.types

import backend.data.graph
import backend.server.v2.store.exceptions
import backend.server.v2.store.model
from backend.data.graph import GraphModel

logger = logging.getLogger(__name__)

@@ -29,29 +24,6 @@ async def get_store_agents(
    logger.debug(
        f"Getting store agents. featured={featured}, creator={creator}, sorted_by={sorted_by}, search={search_query}, category={category}, page={page}"
    )
    sanitized_query = None
    # Sanitize and validate search query by escaping special characters
    if search_query is not None:
        sanitized_query = search_query.strip()
        if not sanitized_query or len(sanitized_query) > 100:  # Reasonable length limit
            raise backend.server.v2.store.exceptions.DatabaseError(
                f"Invalid search query: len({len(sanitized_query)}) query: {search_query}"
            )

        # Escape special SQL characters
        sanitized_query = (
            sanitized_query.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
            .replace("[", "\\[")
            .replace("]", "\\]")
            .replace("'", "\\'")
            .replace('"', '\\"')
            .replace(";", "\\;")
            .replace("--", "\\--")
            .replace("/*", "\\/*")
            .replace("*/", "\\*/")
        )
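        # The replace() chain above escapes SQL LIKE wildcards ("%", "_") and
        # quote/comment tokens before the value is used in the `contains` filter below.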

    where_clause = {}
    if featured:
@@ -60,11 +32,10 @@ async def get_store_agents(
        where_clause["creator_username"] = creator
    if category:
        where_clause["categories"] = {"has": category}

    if sanitized_query:
    if search_query:
        where_clause["OR"] = [
            {"agent_name": {"contains": sanitized_query, "mode": "insensitive"}},
            {"description": {"contains": sanitized_query, "mode": "insensitive"}},
            {"agent_name": {"contains": search_query, "mode": "insensitive"}},
            {"description": {"contains": search_query, "mode": "insensitive"}},
        ]

    order_by = []
@@ -173,70 +144,40 @@ async def get_store_creators(
        f"Getting store creators. featured={featured}, search={search_query}, sorted_by={sorted_by}, page={page}"
    )

    # Build where clause with sanitized inputs
    # Build where clause
    where = {}

    if featured:
        where["is_featured"] = featured

    # Add search filter if provided, using parameterized queries
    # Add search filter if provided
    if search_query:
        # Sanitize and validate search query by escaping special characters
        sanitized_query = search_query.strip()
        if not sanitized_query or len(sanitized_query) > 100:  # Reasonable length limit
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Invalid search query"
            )

        # Escape special SQL characters
        sanitized_query = (
            sanitized_query.replace("\\", "\\\\")
            .replace("%", "\\%")
            .replace("_", "\\_")
            .replace("[", "\\[")
            .replace("]", "\\]")
            .replace("'", "\\'")
            .replace('"', '\\"')
            .replace(";", "\\;")
            .replace("--", "\\--")
            .replace("/*", "\\/*")
            .replace("*/", "\\*/")
        )

        where["OR"] = [
            {"username": {"contains": sanitized_query, "mode": "insensitive"}},
            {"name": {"contains": sanitized_query, "mode": "insensitive"}},
            {"description": {"contains": sanitized_query, "mode": "insensitive"}},
            {"username": {"contains": search_query, "mode": "insensitive"}},
            {"name": {"contains": search_query, "mode": "insensitive"}},
            {"description": {"contains": search_query, "mode": "insensitive"}},
        ]

    try:
        # Validate pagination parameters
        if not isinstance(page, int) or page < 1:
            raise backend.server.v2.store.exceptions.DatabaseError(
                "Invalid page number"
            )
        if not isinstance(page_size, int) or page_size < 1 or page_size > 100:
            raise backend.server.v2.store.exceptions.DatabaseError("Invalid page size")

        # Get total count for pagination using sanitized where clause
        # Get total count for pagination
        total = await prisma.models.Creator.prisma().count(
            where=prisma.types.CreatorWhereInput(**where)
        )
        total_pages = (total + page_size - 1) // page_size
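        # Ceiling division: e.g. total=45 with page_size=20 -> total_pages=3.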

        # Add pagination with validated parameters
        # Add pagination
        skip = (page - 1) * page_size
        take = page_size
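        # e.g. page=3, page_size=20 -> skip=40, take=20.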

        # Add sorting with validated sort parameter
        # Add sorting
        order = []
        valid_sort_fields = {"agent_rating", "agent_runs", "num_agents"}
        if sorted_by in valid_sort_fields:
            order.append({sorted_by: "desc"})
        if sorted_by == "agent_rating":
            order.append({"agent_rating": "desc"})
        elif sorted_by == "agent_runs":
            order.append({"agent_runs": "desc"})
        elif sorted_by == "num_agents":
            order.append({"num_agents": "desc"})
        else:
            order.append({"username": "asc"})

        # Execute query with sanitized parameters
        # Execute query
        creators = await prisma.models.Creator.prisma().find_many(
            where=prisma.types.CreatorWhereInput(**where),
            skip=skip,
@@ -254,7 +195,6 @@ async def get_store_creators(
                num_agents=creator.num_agents,
                agent_rating=creator.agent_rating,
                agent_runs=creator.agent_runs,
                is_featured=creator.is_featured,
            )
            for creator in creators
        ]
@@ -453,11 +393,6 @@ async def create_store_submission(
    )

    try:
        # Sanitize slug to only allow letters and hyphens
        slug = "".join(
            c if c.isalpha() or c == "-" or c.isnumeric() else "" for c in slug
        ).lower()
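        # Keeps letters, digits and hyphens, then lowercases:
        # e.g. "My Agent! v2" -> "myagentv2".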

        # First verify the agent belongs to this user
        agent = await prisma.models.AgentGraph.prisma().find_first(
            where=prisma.types.AgentGraphWhereInput(
@@ -587,22 +522,22 @@ async def get_user_profile(

        if not profile:
            logger.warning(f"Profile not found for user {user_id}")
            new_profile = await prisma.models.Profile.prisma().create(
            await prisma.models.Profile.prisma().create(
                data=prisma.types.ProfileCreateInput(
                    userId=user_id,
                    name="No Profile Data",
                    username=f"{random.choice(['happy', 'clever', 'swift', 'bright', 'wise'])}-{random.choice(['fox', 'wolf', 'bear', 'eagle', 'owl'])}_{random.randint(1000,9999)}".lower(),
                    username=f"{random.choice(['happy', 'clever', 'swift', 'bright', 'wise'])}-{random.choice(['fox', 'wolf', 'bear', 'eagle', 'owl'])}_{random.randint(1000,9999)}",
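                    # e.g. "swift-owl_4821"; both word lists are already lowercase,
                    # so the .lower() on the old line was a no-op.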
                    description="No Profile Data",
                    links=[],
                    avatarUrl="",
                )
            )
            return backend.server.v2.store.model.ProfileDetails(
                name=new_profile.name,
                username=new_profile.username,
                description=new_profile.description,
                links=new_profile.links,
                avatar_url=new_profile.avatarUrl,
                name="No Profile Data",
                username="No Profile Data",
                description="No Profile Data",
                links=[],
                avatar_url="",
            )

        return backend.server.v2.store.model.ProfileDetails(
@@ -629,7 +564,6 @@ async def update_or_create_profile(
    """
    Update the store profile for a user. Creates a new profile if one doesn't exist.
    Only allows updating if the user_id matches the owning user.
    If a field is None, it will not overwrite the existing value in the case of an update.

    Args:
        user_id: ID of the authenticated user
@@ -640,36 +574,27 @@ async def update_or_create_profile(

    Raises:
        HTTPException: If user is not authorized to update this profile
        DatabaseError: If profile cannot be updated due to database issues
    """
    logger.info(f"Updating profile for user {user_id} data: {profile}")
    logger.debug(f"Updating profile for user {user_id}")

    try:
        # Sanitize username to only allow letters and hyphens
        username = "".join(
            c if c.isalpha() or c == "-" or c.isnumeric() else ""
            for c in profile.username
        ).lower()

        # Check if profile exists for user
        existing_profile = await prisma.models.Profile.prisma().find_first(
            where={"userId": user_id}
        )

        # If no profile exists, create a new one
        if not existing_profile:
            logger.debug(
                f"No existing profile found. Creating new profile for user {user_id}"
            )
            logger.debug(f"Creating new profile for user {user_id}")
            # Create new profile since one doesn't exist
            new_profile = await prisma.models.Profile.prisma().create(
                data={
                    "userId": user_id,
                    "name": profile.name,
                    "username": username,
                    "username": profile.username,
                    "description": profile.description,
                    "links": profile.links or [],
                    "links": profile.links,
                    "avatarUrl": profile.avatar_url,
                    "isFeatured": False,
                }
            )

@@ -685,23 +610,16 @@ async def update_or_create_profile(
            )
        else:
            logger.debug(f"Updating existing profile for user {user_id}")
            # Update only provided fields for the existing profile
            update_data = {}
            if profile.name is not None:
                update_data["name"] = profile.name
            if profile.username is not None:
                update_data["username"] = username
            if profile.description is not None:
                update_data["description"] = profile.description
            if profile.links is not None:
                update_data["links"] = profile.links
            if profile.avatar_url is not None:
                update_data["avatarUrl"] = profile.avatar_url

            # Update the existing profile
            updated_profile = await prisma.models.Profile.prisma().update(
                where={"id": existing_profile.id},
                data=prisma.types.ProfileUpdateInput(**update_data),
                data=prisma.types.ProfileUpdateInput(
                    name=profile.name,
                    username=profile.username,
                    description=profile.description,
                    links=profile.links,
                    avatarUrl=profile.avatar_url,
                ),
            )
            if updated_profile is None:
                logger.error(f"Failed to update profile for user {user_id}")
@@ -771,7 +689,6 @@ async def get_my_agents(
                agent_version=agent.version,
                agent_name=agent.name or "",
                last_edited=agent.updatedAt or agent.createdAt,
                description=agent.description or "",
            )
            for agent in agents
        ]
@@ -790,45 +707,3 @@ async def get_my_agents(
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch my agents"
        ) from e


async def get_agent(
    store_listing_version_id: str, version_id: Optional[int]
) -> GraphModel:
    """Get agent using the version ID and store listing version ID."""
    try:
        store_listing_version = (
            await prisma.models.StoreListingVersion.prisma().find_unique(
                where={"id": store_listing_version_id}, include={"Agent": True}
            )
        )

        if not store_listing_version or not store_listing_version.Agent:
            raise fastapi.HTTPException(
                status_code=404,
                detail=f"Store listing version {store_listing_version_id} not found",
            )

        agent = store_listing_version.Agent

        graph = await backend.data.graph.get_graph(
            agent.id, agent.version, template=True
        )

        if not graph:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Agent {agent.id} not found"
            )

        graph.version = 1
        graph.is_template = False
        graph.is_active = True
        delattr(graph, "user_id")
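        # Normalize the served copy: pin the version to 1, clear the template flag,
        # mark it active, and drop the owning user's id from the payload.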

        return graph

    except Exception as e:
        logger.error(f"Error getting agent: {str(e)}")
        raise backend.server.v2.store.exceptions.DatabaseError(
            "Failed to fetch agent"
        ) from e

@@ -113,7 +113,6 @@ async def test_get_store_creator_details(mocker):
        agent_rating=4.5,
        agent_runs=10,
        top_categories=["test"],
        is_featured=False,
    )

    # Mock prisma call
@@ -198,7 +197,6 @@ async def test_update_profile(mocker):
        description="Test description",
        links=["link1"],
        avatarUrl="avatar.jpg",
        isFeatured=False,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )
@@ -217,7 +215,6 @@ async def test_update_profile(mocker):
        description="Test description",
        links=["link1"],
        avatar_url="avatar.jpg",
        is_featured=False,
    )

    # Call function
@@ -242,7 +239,6 @@ async def test_get_user_profile(mocker):
        description="Test description",
        links=["link1", "link2"],
        avatarUrl="avatar.jpg",
        isFeatured=False,
        createdAt=datetime.now(),
        updatedAt=datetime.now(),
    )

@@ -1,94 +0,0 @@
import io
import logging
from enum import Enum

import replicate
import replicate.exceptions
import requests
from replicate.helpers import FileOutput

from backend.data.graph import Graph
from backend.util.settings import Settings

logger = logging.getLogger(__name__)


class ImageSize(str, Enum):
    LANDSCAPE = "1024x768"


class ImageStyle(str, Enum):
    DIGITAL_ART = "digital art"


async def generate_agent_image(agent: Graph) -> io.BytesIO:
    """
    Generate an image for an agent using Flux model via Replicate API.

    Args:
        agent (Graph): The agent to generate an image for

    Returns:
        io.BytesIO: The generated image as bytes
    """
    try:
        settings = Settings()

        if not settings.secrets.replicate_api_key:
            raise ValueError("Missing Replicate API key in settings")

        # Construct prompt from agent details
        prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design."

        # Set up Replicate client
        client = replicate.Client(api_token=settings.secrets.replicate_api_key)

        # Model parameters
        input_data = {
            "prompt": prompt,
            "width": 1024,
            "height": 768,
            "aspect_ratio": "4:3",
            "output_format": "jpg",
            "output_quality": 90,
            "num_inference_steps": 30,
            "guidance": 3.5,
            "negative_prompt": "blurry, low quality, distorted, deformed",
            "disable_safety_checker": True,
        }

        try:
            # Run model
            output = client.run("black-forest-labs/flux-1.1-pro", input=input_data)

            # Depending on the model output, extract the image URL or bytes
            # If the output is a list of FileOutput or URLs
            if isinstance(output, list) and output:
                if isinstance(output[0], FileOutput):
                    image_bytes = output[0].read()
                else:
                    # If it's a URL string, fetch the image bytes
                    result_url = output[0]
                    response = requests.get(result_url)
                    response.raise_for_status()
                    image_bytes = response.content
            elif isinstance(output, FileOutput):
                image_bytes = output.read()
            elif isinstance(output, str):
                # Output is a URL
                response = requests.get(output)
                response.raise_for_status()
                image_bytes = response.content
            else:
                raise RuntimeError("Unexpected output format from the model.")

            return io.BytesIO(image_bytes)

        except replicate.exceptions.ReplicateError as e:
            if e.status == 401:
                raise RuntimeError("Invalid Replicate API token") from e
            raise RuntimeError(f"Replicate API error: {str(e)}") from e

    except Exception as e:
        logger.exception("Failed to generate agent image")
        raise RuntimeError(f"Image generation failed: {str(e)}")
@@ -15,116 +15,22 @@ ALLOWED_VIDEO_TYPES = {"video/mp4", "video/webm"}
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB


async def check_media_exists(user_id: str, filename: str) -> str | None:
    """
    Check if a media file exists in storage for the given user.
    Tries both images and videos directories.

    Args:
        user_id (str): ID of the user who uploaded the file
        filename (str): Name of the file to check

    Returns:
        str | None: URL of the blob if it exists, None otherwise
    """
    try:
        settings = Settings()
        storage_client = storage.Client()
        bucket = storage_client.bucket(settings.config.media_gcs_bucket_name)

        # Check images
        image_path = f"users/{user_id}/images/{filename}"
        image_blob = bucket.blob(image_path)
        if image_blob.exists():
            return image_blob.public_url

        # Check videos
        video_path = f"users/{user_id}/videos/{filename}"

        video_blob = bucket.blob(video_path)
        if video_blob.exists():
            return video_blob.public_url

        return None
    except Exception as e:
        logger.error(f"Error checking if media file exists: {str(e)}")
        return None


async def upload_media(
    user_id: str, file: fastapi.UploadFile, use_file_name: bool = False
) -> str:

    # Get file content for deeper validation
    try:
        content = await file.read(1024)  # Read first 1KB for validation
        await file.seek(0)  # Reset file pointer
    except Exception as e:
        logger.error(f"Error reading file content: {str(e)}")
        raise backend.server.v2.store.exceptions.FileReadError(
            "Failed to read file content"
        ) from e

    # Validate file signature/magic bytes
    if file.content_type in ALLOWED_IMAGE_TYPES:
        # Check image file signatures
        if content.startswith(b"\xFF\xD8\xFF"):  # JPEG
            if file.content_type != "image/jpeg":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"\x89PNG\r\n\x1a\n"):  # PNG
            if file.content_type != "image/png":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"GIF87a") or content.startswith(b"GIF89a"):  # GIF
            if file.content_type != "image/gif":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"RIFF") and content[8:12] == b"WEBP":  # WebP
            if file.content_type != "image/webp":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        else:
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                "Invalid image file signature"
            )

    elif file.content_type in ALLOWED_VIDEO_TYPES:
        # Check video file signatures
        if content.startswith(b"\x00\x00\x00") and (content[4:8] == b"ftyp"):  # MP4
            if file.content_type != "video/mp4":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        elif content.startswith(b"\x1a\x45\xdf\xa3"):  # WebM
            if file.content_type != "video/webm":
                raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                    "File signature does not match content type"
                )
        else:
            raise backend.server.v2.store.exceptions.InvalidFileTypeError(
                "Invalid video file signature"
            )
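    # Only the first 1KB is inspected: each branch compares well-known magic bytes
    # (JPEG FF D8 FF, PNG 89 50 4E 47 0D 0A 1A 0A, GIF87a/GIF89a, RIFF....WEBP,
    # the MP4 "ftyp" box, WebM 1A 45 DF A3) against the declared Content-Type.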

async def upload_media(user_id: str, file: fastapi.UploadFile) -> str:
    settings = Settings()

    # Check required settings first before doing any file processing
    if not settings.config.media_gcs_bucket_name:
        logger.error("Missing GCS bucket name setting")
    if (
        not settings.config.media_gcs_bucket_name
        or not settings.config.google_application_credentials
    ):
        logger.error("Missing required GCS settings")
        raise backend.server.v2.store.exceptions.StorageConfigError(
            "Missing storage bucket configuration"
            "Missing storage configuration"
        )

    try:
        # Validate file type
        content_type = file.content_type
        if content_type is None:
            content_type = "image/jpeg"

        if (
            content_type not in ALLOWED_IMAGE_TYPES
            and content_type not in ALLOWED_VIDEO_TYPES
@@ -160,10 +66,7 @@ async def upload_media(
        # Generate unique filename
        filename = file.filename or ""
        file_ext = os.path.splitext(filename)[1].lower()
        if use_file_name:
            unique_filename = filename
        else:
            unique_filename = f"{uuid.uuid4()}{file_ext}"
        unique_filename = f"{uuid.uuid4()}{file_ext}"

        # Construct storage path
        media_type = "images" if content_type in ALLOWED_IMAGE_TYPES else "videos"

@@ -27,7 +27,7 @@ def mock_storage_client(mocker):

    mock_client.bucket.return_value = mock_bucket
    mock_bucket.blob.return_value = mock_blob
    mock_blob.public_url = "http://test-url/media/laptop.jpeg"
    mock_blob.public_url = "http://test-url/media/test.jpg"

    mocker.patch("google.cloud.storage.Client", return_value=mock_client)

@@ -35,18 +35,15 @@ def mock_storage_client(mocker):


async def test_upload_media_success(mock_settings, mock_storage_client):
    # Create test JPEG data with valid signature
    test_data = b"\xFF\xD8\xFF" + b"test data"

    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(test_data),
        filename="test.jpeg",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/laptop.jpeg"
    assert result == "http://test-url/media/test.jpg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()
@@ -74,8 +71,8 @@ async def test_upload_media_missing_credentials(monkeypatch):
    monkeypatch.setattr("backend.server.v2.store.media.Settings", lambda: settings)

    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(b"\xFF\xD8\xFF" + b"test data"),  # Valid JPEG signature
        filename="test.jpeg",
        file=io.BytesIO(b"test data"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

@@ -86,105 +83,25 @@ async def test_upload_media_missing_credentials(monkeypatch):
async def test_upload_media_video_type(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.mp4",
        file=io.BytesIO(b"\x00\x00\x00\x18ftypmp42"),  # Valid MP4 signature
        file=io.BytesIO(b"test video data"),
        headers=starlette.datastructures.Headers({"content-type": "video/mp4"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)

    assert result == "http://test-url/media/laptop.jpeg"
    assert result == "http://test-url/media/test.jpg"
    mock_bucket = mock_storage_client.bucket.return_value
    mock_blob = mock_bucket.blob.return_value
    mock_blob.upload_from_string.assert_called_once()


async def test_upload_media_file_too_large(mock_settings, mock_storage_client):
    large_data = b"\xFF\xD8\xFF" + b"x" * (
        50 * 1024 * 1024 + 1
    )  # 50MB + 1 byte with valid JPEG signature
    large_data = b"x" * (50 * 1024 * 1024 + 1)  # 50MB + 1 byte
    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        filename="test.jpeg",
        file=io.BytesIO(large_data),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.FileSizeTooLargeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_file_read_error(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="laptop.jpeg",
        file=io.BytesIO(b""),  # Empty file that will raise error on read
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )
    test_file.read = unittest.mock.AsyncMock(side_effect=Exception("Read error"))

    with pytest.raises(backend.server.v2.store.exceptions.FileReadError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_png_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.png",
        file=io.BytesIO(b"\x89PNG\r\n\x1a\n"),  # Valid PNG signature
        headers=starlette.datastructures.Headers({"content-type": "image/png"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_gif_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.gif",
        file=io.BytesIO(b"GIF89a"),  # Valid GIF signature
        headers=starlette.datastructures.Headers({"content-type": "image/gif"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_webp_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.webp",
        file=io.BytesIO(b"RIFF\x00\x00\x00\x00WEBP"),  # Valid WebP signature
        headers=starlette.datastructures.Headers({"content-type": "image/webp"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_webm_success(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.webm",
        file=io.BytesIO(b"\x1a\x45\xdf\xa3"),  # Valid WebM signature
        headers=starlette.datastructures.Headers({"content-type": "video/webm"}),
    )

    result = await backend.server.v2.store.media.upload_media("test-user", test_file)
    assert result == "http://test-url/media/laptop.jpeg"


async def test_upload_media_mismatched_signature(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"\x89PNG\r\n\x1a\n"),  # PNG signature with JPEG content type
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)


async def test_upload_media_invalid_signature(mock_settings, mock_storage_client):
    test_file = fastapi.UploadFile(
        filename="test.jpeg",
        file=io.BytesIO(b"invalid signature"),
        headers=starlette.datastructures.Headers({"content-type": "image/jpeg"}),
    )

    with pytest.raises(backend.server.v2.store.exceptions.InvalidFileTypeError):
        await backend.server.v2.store.media.upload_media("test-user", test_file)

@@ -24,7 +24,6 @@ class MyAgent(pydantic.BaseModel):
    agent_id: str
    agent_version: int
    agent_name: str
    description: str
    last_edited: datetime.datetime


@@ -75,7 +74,6 @@ class Creator(pydantic.BaseModel):
    num_agents: int
    agent_rating: float
    agent_runs: int
    is_featured: bool


class CreatorsResponse(pydantic.BaseModel):
@@ -100,7 +98,6 @@ class Profile(pydantic.BaseModel):
    description: str
    links: list[str]
    avatar_url: str
    is_featured: bool = False


class StoreSubmission(pydantic.BaseModel):

@@ -88,7 +88,6 @@ def test_creator():
        description="Test description",
        avatar_url="avatar.jpg",
        num_agents=5,
        is_featured=False,
    )
    assert creator.name == "Test Creator"
    assert creator.num_agents == 5
@@ -105,7 +104,6 @@ def test_creators_response():
                description="Test description",
                avatar_url="avatar.jpg",
                num_agents=5,
                is_featured=False,
            )
        ],
        pagination=backend.server.v2.store.model.Pagination(

@@ -1,19 +1,12 @@
import json
import logging
import tempfile
import typing
import urllib.parse

import autogpt_libs.auth.depends
import autogpt_libs.auth.middleware
import fastapi
import fastapi.responses
from fastapi.encoders import jsonable_encoder

import backend.data.block
import backend.data.graph
import backend.server.v2.store.db
import backend.server.v2.store.image_gen
import backend.server.v2.store.media
import backend.server.v2.store.model

@@ -27,16 +20,12 @@ router = fastapi.APIRouter()
##############################################


@router.get(
    "/profile",
    tags=["store", "private"],
    response_model=backend.server.v2.store.model.ProfileDetails,
)
@router.get("/profile", tags=["store", "private"])
async def get_profile(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
):
) -> backend.server.v2.store.model.ProfileDetails:
    """
    Get the profile details for the authenticated user.
    """
@@ -45,24 +34,20 @@ async def get_profile(
        return profile
    except Exception:
        logger.exception("Exception occurred whilst getting user profile")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the user profile"},
        )
        raise


@router.post(
    "/profile",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=backend.server.v2.store.model.CreatorDetails,
)
async def update_or_create_profile(
    profile: backend.server.v2.store.model.Profile,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
):
) -> backend.server.v2.store.model.CreatorDetails:
    """
    Update the store profile for the authenticated user.

@@ -83,10 +68,7 @@ async def update_or_create_profile(
        return updated_profile
    except Exception:
        logger.exception("Exception occurred whilst updating profile")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while updating the user profile"},
        )
        raise


##############################################
@@ -94,11 +76,7 @@ async def update_or_create_profile(
##############################################


@router.get(
    "/agents",
    tags=["store", "public"],
    response_model=backend.server.v2.store.model.StoreAgentsResponse,
)
@router.get("/agents", tags=["store", "public"])
async def get_agents(
    featured: bool = False,
    creator: str | None = None,
@@ -107,7 +85,7 @@ async def get_agents(
    category: str | None = None,
    page: int = 1,
    page_size: int = 20,
):
) -> backend.server.v2.store.model.StoreAgentsResponse:
    """
    Get a paginated list of agents from the store with optional filtering and sorting.

@@ -157,46 +135,32 @@ async def get_agents(
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting store agents")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the store agents"},
        )
        raise


@router.get(
    "/agents/{username}/{agent_name}",
    tags=["store", "public"],
    response_model=backend.server.v2.store.model.StoreAgentDetails,
)
async def get_agent(username: str, agent_name: str):
@router.get("/agents/{username}/{agent_name}", tags=["store", "public"])
async def get_agent(
    username: str, agent_name: str
) -> backend.server.v2.store.model.StoreAgentDetails:
    """
    This is only used on the AgentDetails Page

    It returns the store listing agent's details.
    """
    try:
        username = urllib.parse.unquote(username).lower()
        # URL decode the agent name since it comes from the URL path
        agent_name = urllib.parse.unquote(agent_name).lower()
        agent = await backend.server.v2.store.db.get_store_agent_details(
            username=username, agent_name=agent_name
        )
        return agent
    except Exception:
        logger.exception("Exception occurred whilst getting store agent details")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the store agent details"
            },
        )
        raise


@router.post(
    "/agents/{username}/{agent_name}/review",
    tags=["store"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=backend.server.v2.store.model.StoreReview,
)
async def create_review(
    username: str,
@@ -205,7 +169,7 @@ async def create_review(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
):
) -> backend.server.v2.store.model.StoreReview:
    """
    Create a review for a store agent.

@@ -219,8 +183,6 @@ async def create_review(
        The created review
    """
    try:
        username = urllib.parse.unquote(username).lower()
        agent_name = urllib.parse.unquote(agent_name)
        # Create the review
        created_review = await backend.server.v2.store.db.create_store_review(
            user_id=user_id,
@@ -232,10 +194,7 @@ async def create_review(
        return created_review
    except Exception:
        logger.exception("Exception occurred whilst creating store review")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while creating the store review"},
        )
        raise


##############################################
@@ -243,18 +202,14 @@ async def create_review(
##############################################


@router.get(
    "/creators",
    tags=["store", "public"],
    response_model=backend.server.v2.store.model.CreatorsResponse,
)
@router.get("/creators", tags=["store", "public"])
async def get_creators(
    featured: bool = False,
    search_query: str | None = None,
    sorted_by: str | None = None,
    page: int = 1,
    page_size: int = 20,
):
) -> backend.server.v2.store.model.CreatorsResponse:
    """
    This is needed for:
    - Home Page Featured Creators
@@ -288,38 +243,23 @@ async def get_creators(
        return creators
    except Exception:
        logger.exception("Exception occurred whilst getting store creators")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the store creators"},
        )
        raise


@router.get(
    "/creator/{username}",
    tags=["store", "public"],
    response_model=backend.server.v2.store.model.CreatorDetails,
)
async def get_creator(
    username: str,
):
@router.get("/creator/{username}", tags=["store", "public"])
async def get_creator(username: str) -> backend.server.v2.store.model.CreatorDetails:
    """
    Get the details of a creator
    - Creator Details Page
    """
    try:
        username = urllib.parse.unquote(username).lower()
        creator = await backend.server.v2.store.db.get_store_creator_details(
            username=username.lower()
            username=username
        )
        return creator
    except Exception:
        logger.exception("Exception occurred whilst getting creator details")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the creator details"
            },
        )
        raise


############################################
@@ -329,36 +269,31 @@ async def get_creator(
    "/myagents",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=backend.server.v2.store.model.MyAgentsResponse,
)
async def get_my_agents(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ]
):
) -> backend.server.v2.store.model.MyAgentsResponse:
    try:
        agents = await backend.server.v2.store.db.get_my_agents(user_id)
        return agents
    except Exception:
        logger.exception("Exception occurred whilst getting my agents")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while retrieving the my agents"},
        )
        raise


@router.delete(
    "/submissions/{submission_id}",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=bool,
)
async def delete_submission(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
    submission_id: str,
):
) -> bool:
    """
    Delete a store listing submission.

@@ -377,17 +312,13 @@ async def delete_submission(
        return result
    except Exception:
        logger.exception("Exception occurred whilst deleting store submission")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while deleting the store submission"},
        )
        raise


@router.get(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=backend.server.v2.store.model.StoreSubmissionsResponse,
)
async def get_submissions(
    user_id: typing.Annotated[
@@ -395,7 +326,7 @@ async def get_submissions(
    ],
    page: int = 1,
    page_size: int = 20,
):
) -> backend.server.v2.store.model.StoreSubmissionsResponse:
    """
    Get a paginated list of store submissions for the authenticated user.

@@ -428,26 +359,20 @@ async def get_submissions(
        return listings
    except Exception:
        logger.exception("Exception occurred whilst getting store submissions")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={
                "detail": "An error occurred while retrieving the store submissions"
            },
        )
        raise


@router.post(
    "/submissions",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
    response_model=backend.server.v2.store.model.StoreSubmission,
)
async def create_submission(
    submission_request: backend.server.v2.store.model.StoreSubmissionRequest,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
):
) -> backend.server.v2.store.model.StoreSubmission:
    """
    Create a new store listing submission.

@@ -477,10 +402,7 @@ async def create_submission(
        return submission
    except Exception:
        logger.exception("Exception occurred whilst creating store submission")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while creating the store submission"},
        )
        raise


@router.post(
@@ -493,7 +415,7 @@ async def upload_submission_media(
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
):
) -> str:
    """
    Upload media (images/videos) for a store listing submission.

@@ -514,131 +436,4 @@ async def upload_submission_media(
        return media_url
    except Exception:
        logger.exception("Exception occurred whilst uploading submission media")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while uploading the media file"},
        )


@router.post(
    "/submissions/generate_image",
    tags=["store", "private"],
    dependencies=[fastapi.Depends(autogpt_libs.auth.middleware.auth_middleware)],
)
async def generate_image(
    agent_id: str,
    user_id: typing.Annotated[
        str, fastapi.Depends(autogpt_libs.auth.depends.get_user_id)
    ],
) -> fastapi.responses.Response:
    """
    Generate an image for a store listing submission.

    Args:
        agent_id (str): ID of the agent to generate an image for
        user_id (str): ID of the authenticated user

    Returns:
        JSONResponse: JSON containing the URL of the generated image
    """
    try:
        agent = await backend.data.graph.get_graph(agent_id, user_id=user_id)

        if not agent:
            raise fastapi.HTTPException(
                status_code=404, detail=f"Agent with ID {agent_id} not found"
            )
        # Use .jpeg here since we are generating JPEG images
        filename = f"agent_{agent_id}.jpeg"

        existing_url = await backend.server.v2.store.media.check_media_exists(
            user_id, filename
        )
        if existing_url:
            logger.info(f"Using existing image for agent {agent_id}")
            return fastapi.responses.JSONResponse(content={"image_url": existing_url})
        # Generate agent image as JPEG
        image = await backend.server.v2.store.image_gen.generate_agent_image(
            agent=agent
        )

        # Create UploadFile with the correct filename and content_type
        image_file = fastapi.UploadFile(
            file=image,
            filename=filename,
        )

        image_url = await backend.server.v2.store.media.upload_media(
            user_id=user_id, file=image_file, use_file_name=True
        )

        return fastapi.responses.JSONResponse(content={"image_url": image_url})
    except Exception:
        logger.exception("Exception occurred whilst generating submission image")
        return fastapi.responses.JSONResponse(
            status_code=500,
            content={"detail": "An error occurred while generating the image"},
        )


@router.get(
    "/download/agents/{store_listing_version_id}",
    tags=["store", "public"],
)
async def download_agent_file(
    store_listing_version_id: str = fastapi.Path(
        ..., description="The ID of the agent to download"
    ),
    version: typing.Optional[int] = fastapi.Query(
        None, description="Specific version of the agent"
    ),
) -> fastapi.responses.FileResponse:
    """
    Download the agent file by streaming its content.

    Args:
        store_listing_version_id (str): The ID of the agent to download.
        version (Optional[int]): Specific version of the agent to download.

    Returns:
        StreamingResponse: A streaming response containing the agent's graph data.

    Raises:
        HTTPException: If the agent is not found or an unexpected error occurs.
    """

    graph_data = await backend.server.v2.store.db.get_agent(
        store_listing_version_id=store_listing_version_id, version_id=version
    )

    graph_data.clean_graph()
    graph_date_dict = jsonable_encoder(graph_data)

    def remove_credentials(obj):
        if obj and isinstance(obj, dict):
            if "credentials" in obj:
                del obj["credentials"]
            if "creds" in obj:
                del obj["creds"]

            for value in obj.values():
                remove_credentials(value)
        elif isinstance(obj, list):
            for item in obj:
                remove_credentials(item)
        return obj
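    # Recursive scrub: walks nested dicts and lists so any "credentials"/"creds"
    # keys buried in node input data are stripped before the file is written.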

    graph_date_dict = remove_credentials(graph_date_dict)

    file_name = f"agent_{store_listing_version_id}_v{version or 'latest'}.json"

    # Sending graph as a stream (similar to marketplace v1)
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".json", delete=False
    ) as tmp_file:
        tmp_file.write(json.dumps(graph_date_dict))
        tmp_file.flush()

        return fastapi.responses.FileResponse(
            tmp_file.name, filename=file_name, media_type="application/json"
        )
    raise

@@ -402,7 +402,6 @@ def test_get_creators_pagination(mocker: pytest_mock.MockFixture):
                num_agents=1,
                agent_rating=4.5,
                agent_runs=100,
                is_featured=False,
            )
            for i in range(5)
        ],

@@ -38,7 +38,7 @@ def create_test_graph() -> graph.Graph:
        graph.Node(
            block_id=FillTextTemplateBlock().id,
            input_default={
                "format": "{{a}}, {{b}}{{c}}",
                "format": "{a}, {b}{c}",
                "values_#_c": "!!!",
            },
        ),