mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-12 08:38:09 -05:00
Compare commits
195 Commits
benchmark/
...
remove-git
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
94f0cfd38a | ||
|
|
08c32a7a12 | ||
|
|
56104bd047 | ||
|
|
2ef5cd7d4c | ||
|
|
74b3aae5c6 | ||
|
|
e9b3b5090c | ||
|
|
9bac6f4ce2 | ||
|
|
39c46ef6be | ||
|
|
78d83bb3ce | ||
|
|
d57ccf7ec9 | ||
|
|
ada2e19829 | ||
|
|
a7c7a5e18b | ||
|
|
180de0c9a9 | ||
|
|
8f0d5c73b3 | ||
|
|
3b00e8229c | ||
|
|
e97726cde3 | ||
|
|
d38e8b8f6c | ||
|
|
0014e2ac14 | ||
|
|
370615e5e4 | ||
|
|
f93c743d03 | ||
|
|
6add645597 | ||
|
|
bdda3a6698 | ||
|
|
126aacb2e3 | ||
|
|
1afc8e40df | ||
|
|
9543e5d6ac | ||
|
|
5e89b8c6d1 | ||
|
|
fd3f8fa5fc | ||
|
|
86bdbb82b1 | ||
|
|
898317c16c | ||
|
|
0704404344 | ||
|
|
a74548d3cd | ||
|
|
6ff02677d2 | ||
|
|
4db4ca08b2 | ||
|
|
7bb7c30842 | ||
|
|
35ebb10378 | ||
|
|
b77451bb3a | ||
|
|
cf00c33f90 | ||
|
|
61adf58f4f | ||
|
|
452df39a52 | ||
|
|
49a08ba7db | ||
|
|
7082e63b11 | ||
|
|
d7f00a996f | ||
|
|
cf033504c2 | ||
|
|
e866a4ba04 | ||
|
|
90f3c5e2d9 | ||
|
|
fb8ed0b46b | ||
|
|
12640f7092 | ||
|
|
5f9cc585b1 | ||
|
|
262771a69c | ||
|
|
a1ffe15142 | ||
|
|
30bc761391 | ||
|
|
2a0e087461 | ||
|
|
828b81e5ef | ||
|
|
fe3f835b3e | ||
|
|
6dd76afad5 | ||
|
|
20041d65bf | ||
|
|
028d2c319f | ||
|
|
9e39937072 | ||
|
|
07a3c1848c | ||
|
|
dde0c70a81 | ||
|
|
76d6e61941 | ||
|
|
bca50310f6 | ||
|
|
632686cfa5 | ||
|
|
1262b72f5c | ||
|
|
e985f7c105 | ||
|
|
596487b9ad | ||
|
|
a7c0440e9b | ||
|
|
03ffb50dcf | ||
|
|
e201f57861 | ||
|
|
fea62a77bc | ||
|
|
dfad535dea | ||
|
|
fa14865163 | ||
|
|
ef35028ecb | ||
|
|
fb97e15e4b | ||
|
|
da4f013a5d | ||
|
|
fd2c26188f | ||
|
|
89cf0154f4 | ||
|
|
cb1297ec74 | ||
|
|
37904a0f80 | ||
|
|
6c18627b0f | ||
|
|
d5aa8d373b | ||
|
|
7bf31dad35 | ||
|
|
29d390d54d | ||
|
|
b69f0b2cd0 | ||
|
|
0308fb45be | ||
|
|
0325370fed | ||
|
|
1e4bd0388f | ||
|
|
d1b06f0be3 | ||
|
|
3e40b35ef1 | ||
|
|
70873906b7 | ||
|
|
f93a8a93b4 | ||
|
|
4121d3712d | ||
|
|
4546dfdf17 | ||
|
|
4011294da0 | ||
|
|
48f6f83f05 | ||
|
|
51f5808430 | ||
|
|
695049bfa3 | ||
|
|
40f98f0f38 | ||
|
|
c26c79c34c | ||
|
|
2c96f6125f | ||
|
|
5047fd9fce | ||
|
|
50e5ea4e54 | ||
|
|
ce45c9b267 | ||
|
|
1881f4f7cd | ||
|
|
30762c211e | ||
|
|
5090f55eba | ||
|
|
1f1e8c9f7d | ||
|
|
e44ca4185a | ||
|
|
8fd2e48c1b | ||
|
|
69ccb185e8 | ||
|
|
a88e833831 | ||
|
|
64f48df62d | ||
|
|
0f5490075b | ||
|
|
d5f2bbf093 | ||
|
|
7dd97f2f74 | ||
|
|
8e464c53a8 | ||
|
|
7689a51f53 | ||
|
|
c8a40727d1 | ||
|
|
4ef912d734 | ||
|
|
49a6d68200 | ||
|
|
6cfe229332 | ||
|
|
1079d71699 | ||
|
|
e104427767 | ||
|
|
bfd479a50b | ||
|
|
fb63bf4425 | ||
|
|
3a17011129 | ||
|
|
c339c6b54f | ||
|
|
7f71d6d9fd | ||
|
|
784e2bbb1c | ||
|
|
959377f54c | ||
|
|
6bc83e925c | ||
|
|
4ede773f5a | ||
|
|
d5ad719757 | ||
|
|
1ca9b9fa93 | ||
|
|
15024fb5a1 | ||
|
|
fa4bdef17c | ||
|
|
e2b519ef3b | ||
|
|
09c307d679 | ||
|
|
880c8e804c | ||
|
|
5f0764b65c | ||
|
|
63e6014b27 | ||
|
|
83fcd9ad16 | ||
|
|
f9792ed7f3 | ||
|
|
d6ab470c58 | ||
|
|
666a5a8777 | ||
|
|
21f1e64559 | ||
|
|
752bac099b | ||
|
|
a5de79beb6 | ||
|
|
483c01b681 | ||
|
|
992b8874fc | ||
|
|
2a55efb322 | ||
|
|
23d58a3cc0 | ||
|
|
70e345b2ce | ||
|
|
650a701317 | ||
|
|
679339d00c | ||
|
|
fd5730b04a | ||
|
|
b7f08cd0f7 | ||
|
|
8762f7ab3d | ||
|
|
a9b7b175ff | ||
|
|
52b93dd84e | ||
|
|
6a09a44ef7 | ||
|
|
32a627eda9 | ||
|
|
67bafa6302 | ||
|
|
6017eefb32 | ||
|
|
ae197fc85f | ||
|
|
22aba6dd8a | ||
|
|
88bbdfc7fc | ||
|
|
d0c9b7c405 | ||
|
|
e7698a4610 | ||
|
|
ab05b7ae70 | ||
|
|
327fb1f916 | ||
|
|
bb7f5abc6c | ||
|
|
393d6b97e6 | ||
|
|
3b8d63dfb6 | ||
|
|
6763196d78 | ||
|
|
e1da58da02 | ||
|
|
91cec515d4 | ||
|
|
cc585a014f | ||
|
|
e641cccb42 | ||
|
|
cc73d4104b | ||
|
|
250552cb3d | ||
|
|
1d653973e9 | ||
|
|
7bf9ba5502 | ||
|
|
14c9773890 | ||
|
|
39fddb1214 | ||
|
|
fe0923ba6c | ||
|
|
dfaeda7cd5 | ||
|
|
9b7fee673e | ||
|
|
925269d17b | ||
|
|
266fe3a3f7 | ||
|
|
66e0c87894 | ||
|
|
55433f468a | ||
|
|
956cdc77fa | ||
|
|
83a0b03523 | ||
|
|
25b9e290a5 |
10
.github/CODEOWNERS
vendored
10
.github/CODEOWNERS
vendored
@@ -1,5 +1,5 @@
|
||||
.github/workflows/ @Significant-Gravitas/maintainers
|
||||
autogpts/autogpt/ @Pwuts
|
||||
benchmark/ @Significant-Gravitas/benchmarkers
|
||||
forge/ @Swiftyos
|
||||
frontend/ @hunteraraujo
|
||||
.github/workflows/ @Significant-Gravitas/devops
|
||||
autogpts/autogpt/ @Significant-Gravitas/maintainers
|
||||
autogpts/forge/ @Significant-Gravitas/forge-maintainers
|
||||
benchmark/ @Significant-Gravitas/benchmark-maintainers
|
||||
frontend/ @Significant-Gravitas/frontend-maintainers
|
||||
|
||||
23
.github/labeler.yml
vendored
Normal file
23
.github/labeler.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
AutoGPT Agent:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: autogpts/autogpt/**
|
||||
|
||||
Forge:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: autogpts/forge/**
|
||||
|
||||
Benchmark:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: benchmark/**
|
||||
|
||||
Frontend:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: frontend/**
|
||||
|
||||
Arena:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: arena/**
|
||||
|
||||
documentation:
|
||||
- changed-files:
|
||||
- any-glob-to-any-file: docs/**
|
||||
169
.github/workflows/arena-intake.yml
vendored
Normal file
169
.github/workflows/arena-intake.yml
vendored
Normal file
@@ -0,0 +1,169 @@
|
||||
name: Arena intake
|
||||
|
||||
on:
|
||||
# We recommend `pull_request_target` so that github secrets are available.
|
||||
# In `pull_request` we wouldn't be able to change labels of fork PRs
|
||||
pull_request_target:
|
||||
types: [ opened, synchronize ]
|
||||
paths:
|
||||
- 'arena/**'
|
||||
|
||||
jobs:
|
||||
check:
|
||||
permissions:
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout PR
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
- name: Check Arena entry
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
console.log('⚙️ Setting up...');
|
||||
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
|
||||
const pr = context.payload.pull_request;
|
||||
const isFork = pr.head.repo.fork;
|
||||
|
||||
console.log('🔄️ Fetching PR diff metadata...');
|
||||
const prFilesChanged = (await github.rest.pulls.listFiles({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: pr.number,
|
||||
})).data;
|
||||
console.debug(prFilesChanged);
|
||||
const arenaFilesChanged = prFilesChanged.filter(
|
||||
({ filename: file }) => file.startsWith('arena/') && file.endsWith('.json')
|
||||
);
|
||||
const hasChangesInAutogptsFolder = prFilesChanged.some(
|
||||
({ filename }) => filename.startsWith('autogpts/')
|
||||
);
|
||||
|
||||
console.log(`🗒️ ${arenaFilesChanged.length} arena entries affected`);
|
||||
console.debug(arenaFilesChanged);
|
||||
if (arenaFilesChanged.length === 0) {
|
||||
// If no files in `arena/` are changed, this job does not need to run.
|
||||
return;
|
||||
}
|
||||
|
||||
let close = false;
|
||||
let flagForManualCheck = false;
|
||||
let issues = [];
|
||||
|
||||
if (isFork) {
|
||||
if (arenaFilesChanged.length > 1) {
|
||||
// Impacting multiple entries in `arena/` is not allowed
|
||||
issues.push('This pull request impacts multiple arena entries');
|
||||
}
|
||||
if (hasChangesInAutogptsFolder) {
|
||||
// PRs that include the custom agent are generally not allowed
|
||||
issues.push(
|
||||
'This pull request includes changes in `autogpts/`.\n'
|
||||
+ 'Please make sure to only submit your arena entry (`arena/*.json`), '
|
||||
+ 'and not to accidentally include your custom agent itself.'
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if (arenaFilesChanged.length === 1) {
|
||||
const newArenaFile = arenaFilesChanged[0]
|
||||
const newArenaFileName = path.basename(newArenaFile.filename)
|
||||
console.log(`🗒️ Arena entry in PR: ${newArenaFile}`);
|
||||
|
||||
if (newArenaFile.status != 'added') {
|
||||
flagForManualCheck = true;
|
||||
}
|
||||
|
||||
if (pr.mergeable != false) {
|
||||
const newArenaEntry = JSON.parse(fs.readFileSync(newArenaFile.filename));
|
||||
const allArenaFiles = await (await glob.create('arena/*.json')).glob();
|
||||
console.debug(newArenaEntry);
|
||||
|
||||
console.log(`➡️ Checking ${newArenaFileName} against existing entries...`);
|
||||
for (const file of allArenaFiles) {
|
||||
const existingEntryName = path.basename(file);
|
||||
|
||||
if (existingEntryName === newArenaFileName) {
|
||||
continue;
|
||||
}
|
||||
|
||||
console.debug(`Checking against ${existingEntryName}...`);
|
||||
|
||||
const arenaEntry = JSON.parse(fs.readFileSync(file));
|
||||
if (arenaEntry.github_repo_url === newArenaEntry.github_repo_url) {
|
||||
console.log(`⚠️ Duplicate detected: ${existingEntryName}`);
|
||||
issues.push(
|
||||
`The \`github_repo_url\` specified in __${newArenaFileName}__ `
|
||||
+ `already exists in __${existingEntryName}__. `
|
||||
+ `This PR will be closed as duplicate.`
|
||||
)
|
||||
close = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.log('⚠️ PR has conflicts');
|
||||
issues.push(
|
||||
`__${newArenaFileName}__ conflicts with existing entry with the same name`
|
||||
)
|
||||
close = true;
|
||||
}
|
||||
} // end if (arenaFilesChanged.length === 1)
|
||||
|
||||
console.log('🏁 Finished checking against existing entries');
|
||||
|
||||
if (issues.length == 0) {
|
||||
console.log('✅ No issues detected');
|
||||
if (flagForManualCheck) {
|
||||
console.log('🤔 Requesting review from maintainers...');
|
||||
await github.rest.pulls.requestReviewers({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: pr.number,
|
||||
reviewers: ['Pwuts'],
|
||||
// team_reviewers: ['maintainers'], // doesn't work: https://stackoverflow.com/a/64977184/4751645
|
||||
});
|
||||
} else {
|
||||
console.log('➡️ Approving PR...');
|
||||
await github.rest.pulls.createReview({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: pr.number,
|
||||
event: 'APPROVE',
|
||||
});
|
||||
}
|
||||
} else {
|
||||
console.log(`⚠️ ${issues.length} issues detected`);
|
||||
|
||||
console.log('➡️ Posting comment indicating issues...');
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: pr.number,
|
||||
body: `Our automation found one or more issues with this submission:\n`
|
||||
+ issues.map(i => `- ${i.replace('\n', '\n ')}`).join('\n'),
|
||||
});
|
||||
|
||||
console.log("➡️ Applying label 'invalid'...");
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: pr.number,
|
||||
labels: ['invalid'],
|
||||
});
|
||||
|
||||
if (close) {
|
||||
console.log('➡️ Auto-closing PR...');
|
||||
await github.rest.pulls.update({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
pull_number: pr.number,
|
||||
state: 'closed',
|
||||
});
|
||||
}
|
||||
}
|
||||
96
.github/workflows/autogpt-ci.yml
vendored
96
.github/workflows/autogpt-ci.yml
vendored
@@ -20,6 +20,7 @@ concurrency:
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
working-directory: autogpts/autogpt
|
||||
|
||||
jobs:
|
||||
@@ -30,12 +31,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
@@ -44,7 +45,7 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v3
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('autogpts/autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
|
||||
@@ -77,24 +78,43 @@ jobs:
|
||||
test:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 30
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.10"]
|
||||
|
||||
services:
|
||||
minio:
|
||||
image: minio/minio:edge-cicd
|
||||
ports:
|
||||
- 9000:9000
|
||||
options: >
|
||||
--health-interval=10s --health-timeout=5s --health-retries=3
|
||||
--health-cmd="curl -f http://localhost:9000/minio/health/live"
|
||||
platform-os: [ubuntu, macos, macos-arm64, windows]
|
||||
runs-on: ${{ matrix.platform-os != 'macos-arm64' && format('{0}-latest', matrix.platform-os) || 'macos-14' }}
|
||||
|
||||
steps:
|
||||
# Quite slow on macOS (2~4 minutes to set up Docker)
|
||||
# - name: Set up Docker (macOS)
|
||||
# if: runner.os == 'macOS'
|
||||
# uses: crazy-max/ghaction-setup-docker@v3
|
||||
|
||||
- name: Start MinIO service (Linux)
|
||||
if: runner.os == 'Linux'
|
||||
working-directory: '.'
|
||||
run: |
|
||||
docker pull minio/minio:edge-cicd
|
||||
docker run -d -p 9000:9000 minio/minio:edge-cicd
|
||||
|
||||
- name: Start MinIO service (macOS)
|
||||
if: runner.os == 'macOS'
|
||||
working-directory: ${{ runner.temp }}
|
||||
run: |
|
||||
brew install minio/stable/minio
|
||||
mkdir data
|
||||
minio server ./data &
|
||||
|
||||
# No MinIO on Windows:
|
||||
# - Windows doesn't support running Linux Docker containers
|
||||
# - It doesn't seem possible to start background processes on Windows. They are
|
||||
# killed after the step returns.
|
||||
# See: https://github.com/actions/runner/issues/598#issuecomment-2011890429
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
@@ -136,7 +156,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
@@ -145,15 +165,34 @@ jobs:
|
||||
run: echo "date=$(date +'%Y-%m-%d')" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set up Python dependency cache
|
||||
uses: actions/cache@v3
|
||||
# On Windows, unpacking cached dependencies takes longer than just installing them
|
||||
if: runner.os != 'Windows'
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: ~/.cache/pypoetry
|
||||
key: ${{ runner.os }}-poetry-${{ hashFiles('autogpts/autogpt/pyproject.toml') }}-${{ steps.get_date.outputs.date }}
|
||||
path: ${{ runner.os == 'macOS' && '~/Library/Caches/pypoetry' || '~/.cache/pypoetry' }}
|
||||
key: poetry-${{ runner.os }}-${{ hashFiles('autogpts/autogpt/poetry.lock') }}
|
||||
|
||||
- name: Install Python dependencies
|
||||
- name: Install Poetry (Unix)
|
||||
if: runner.os != 'Windows'
|
||||
run: |
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
poetry install
|
||||
|
||||
if [ "${{ runner.os }}" = "macOS" ]; then
|
||||
PATH="$HOME/.local/bin:$PATH"
|
||||
echo "$HOME/.local/bin" >> $GITHUB_PATH
|
||||
fi
|
||||
|
||||
- name: Install Poetry (Windows)
|
||||
if: runner.os == 'Windows'
|
||||
shell: pwsh
|
||||
run: |
|
||||
(Invoke-WebRequest -Uri https://install.python-poetry.org -UseBasicParsing).Content | python -
|
||||
|
||||
$env:PATH += ";$env:APPDATA\Python\Scripts"
|
||||
echo "$env:APPDATA\Python\Scripts" >> $env:GITHUB_PATH
|
||||
|
||||
- name: Install Python dependencies
|
||||
run: poetry install
|
||||
|
||||
- name: Run pytest with coverage
|
||||
run: |
|
||||
@@ -165,12 +204,15 @@ jobs:
|
||||
CI: true
|
||||
PLAIN_OUTPUT: True
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
S3_ENDPOINT_URL: http://localhost:9000
|
||||
S3_ENDPOINT_URL: ${{ runner.os != 'Windows' && 'http://127.0.0.1:9000' || '' }}
|
||||
AWS_ACCESS_KEY_ID: minioadmin
|
||||
AWS_SECRET_ACCESS_KEY: minioadmin
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
uses: codecov/codecov-action@v3
|
||||
uses: codecov/codecov-action@v4
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
flags: autogpt-agent,${{ runner.os }}
|
||||
|
||||
- id: setup_git_auth
|
||||
name: Set up git token authentication
|
||||
@@ -178,7 +220,11 @@ jobs:
|
||||
if: success() || failure()
|
||||
run: |
|
||||
config_key="http.${{ github.server_url }}/.extraheader"
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
if [ "${{ runner.os }}" = 'macOS' ]; then
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64)
|
||||
else
|
||||
base64_pat=$(echo -n "pat:${{ secrets.PAT_REVIEW }}" | base64 -w0)
|
||||
fi
|
||||
|
||||
git config "$config_key" \
|
||||
"Authorization: Basic $base64_pat"
|
||||
@@ -239,12 +285,12 @@ jobs:
|
||||
echo "Adding label and comment..."
|
||||
echo $TOKEN | gh auth login --with-token
|
||||
gh issue edit $PR_NUMBER --add-label "behaviour change"
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
gh issue comment $PR_NUMBER --body "You changed AutoGPT's behaviour on ${{ runner.os }}. The cassettes have been updated and will be merged to the submodule when this Pull Request gets merged."
|
||||
fi
|
||||
|
||||
- name: Upload logs to artifact
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-logs
|
||||
path: autogpts/autogpt/logs/
|
||||
|
||||
@@ -16,14 +16,14 @@ jobs:
|
||||
build-type: [release, dev]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: autogpts/autogpt
|
||||
build-args: BUILD_TYPE=${{ matrix.build-type }}
|
||||
|
||||
21
.github/workflows/autogpt-docker-ci.yml
vendored
21
.github/workflows/autogpt-docker-ci.yml
vendored
@@ -24,7 +24,7 @@ defaults:
|
||||
|
||||
env:
|
||||
IMAGE_NAME: auto-gpt
|
||||
DEPLOY_IMAGE_NAME: ${{ secrets.DOCKER_USER }}/auto-gpt
|
||||
DEPLOY_IMAGE_NAME: ${{ secrets.DOCKER_USER && format('{0}/', secrets.DOCKER_USER) || '' }}auto-gpt
|
||||
DEV_IMAGE_TAG: latest-dev
|
||||
|
||||
jobs:
|
||||
@@ -35,10 +35,10 @@ jobs:
|
||||
build-type: [release, dev]
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- if: runner.debug
|
||||
run: |
|
||||
@@ -47,11 +47,12 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: autogpts/autogpt
|
||||
build-args: BUILD_TYPE=${{ matrix.build-type }}
|
||||
tags: ${{ env.IMAGE_NAME }}
|
||||
labels: GIT_REVISION=${{ github.sha }}
|
||||
load: true # save to docker images
|
||||
# cache layers in GitHub Actions cache to speed up builds
|
||||
cache-from: type=gha,scope=autogpt-docker-${{ matrix.build-type }}
|
||||
@@ -100,28 +101,30 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Check out repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
|
||||
- name: Log in to Docker hub
|
||||
uses: docker/login-action@v2
|
||||
- if: github.event_name == 'push'
|
||||
name: Log in to Docker hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: autogpts/autogpt
|
||||
build-args: BUILD_TYPE=dev # include pytest
|
||||
tags: >
|
||||
${{ env.IMAGE_NAME }},
|
||||
${{ env.DEPLOY_IMAGE_NAME }}:${{ env.DEV_IMAGE_TAG }}
|
||||
labels: GIT_REVISION=${{ github.sha }}
|
||||
load: true # save to docker images
|
||||
# cache layers in GitHub Actions cache to speed up builds
|
||||
cache-from: type=gha,scope=autogpt-docker-dev
|
||||
|
||||
9
.github/workflows/autogpt-docker-release.yml
vendored
9
.github/workflows/autogpt-docker-release.yml
vendored
@@ -24,16 +24,16 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Log in to Docker hub
|
||||
uses: docker/login-action@v2
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_PASSWORD }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v2
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
# slashes are not allowed in image tags, but can appear in git branch or tag names
|
||||
- id: sanitize_tag
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
|
||||
- id: build
|
||||
name: Build image
|
||||
uses: docker/build-push-action@v3
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: autogpts/autogpt
|
||||
build-args: BUILD_TYPE=release
|
||||
@@ -56,6 +56,7 @@ jobs:
|
||||
${{ env.IMAGE_NAME }},
|
||||
${{ env.DEPLOY_IMAGE_NAME }}:latest,
|
||||
${{ env.DEPLOY_IMAGE_NAME }}:${{ steps.sanitize_tag.outputs.tag }}
|
||||
labels: GIT_REVISION=${{ github.sha }}
|
||||
|
||||
# cache layers in GitHub Actions cache to speed up builds
|
||||
cache-from: ${{ !inputs.no_cache && 'type=gha' || '' }},scope=autogpt-docker-release
|
||||
|
||||
97
.github/workflows/autogpts-benchmark.yml
vendored
Normal file
97
.github/workflows/autogpts-benchmark.yml
vendored
Normal file
@@ -0,0 +1,97 @@
|
||||
name: AutoGPTs Nightly Benchmark
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
schedule:
|
||||
- cron: '0 2 * * *'
|
||||
|
||||
jobs:
|
||||
benchmark:
|
||||
permissions:
|
||||
contents: write
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
agent-name: [ autogpt ]
|
||||
fail-fast: false
|
||||
timeout-minutes: 120
|
||||
env:
|
||||
min-python-version: '3.10'
|
||||
REPORTS_BRANCH: data/benchmark-reports
|
||||
REPORTS_FOLDER: ${{ format('benchmark/reports/{0}', matrix.agent-name) }}
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
- name: Install Poetry
|
||||
run: curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Prepare reports folder
|
||||
run: mkdir -p ${{ env.REPORTS_FOLDER }}
|
||||
|
||||
- run: poetry -C benchmark install
|
||||
|
||||
- name: Benchmark ${{ matrix.agent-name }}
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd autogpts/${{ matrix.agent-name }}
|
||||
|
||||
set +e # Do not quit on non-zero exit codes
|
||||
poetry run agbenchmark run -N 3 \
|
||||
--test=ReadFile \
|
||||
--test=BasicRetrieval --test=RevenueRetrieval2 \
|
||||
--test=CombineCsv --test=LabelCsv --test=AnswerQuestionCombineCsv \
|
||||
--test=UrlShortener --test=TicTacToe --test=Battleship \
|
||||
--test=WebArenaTask_0 --test=WebArenaTask_21 --test=WebArenaTask_124 \
|
||||
--test=WebArenaTask_134 --test=WebArenaTask_163
|
||||
|
||||
# Convert exit code 1 (some challenges failed) to exit code 0
|
||||
if [ $? -eq 0 ] || [ $? -eq 1 ]; then
|
||||
exit 0
|
||||
else
|
||||
exit $?
|
||||
fi
|
||||
env:
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
REPORTS_FOLDER: ${{ format('../../{0}', env.REPORTS_FOLDER) }} # account for changed workdir
|
||||
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
|
||||
- name: Push reports to data branch
|
||||
run: |
|
||||
# BODGE: Remove success_rate.json and regression_tests.json to avoid conflicts on checkout
|
||||
rm ${{ env.REPORTS_FOLDER }}/*.json
|
||||
|
||||
# Find folder with newest (untracked) report in it
|
||||
report_subfolder=$(find ${{ env.REPORTS_FOLDER }} -type f -name 'report.json' \
|
||||
| xargs -I {} dirname {} \
|
||||
| xargs -I {} git ls-files --others --exclude-standard {} \
|
||||
| xargs -I {} dirname {} \
|
||||
| sort -u)
|
||||
json_report_file="$report_subfolder/report.json"
|
||||
|
||||
# Convert JSON report to Markdown
|
||||
markdown_report_file="$report_subfolder/report.md"
|
||||
poetry -C benchmark run benchmark/reports/format.py "$json_report_file" > "$markdown_report_file"
|
||||
cat "$markdown_report_file" >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
git config --global user.name 'GitHub Actions'
|
||||
git config --global user.email 'github-actions@agpt.co'
|
||||
git fetch origin ${{ env.REPORTS_BRANCH }}:${{ env.REPORTS_BRANCH }} \
|
||||
&& git checkout ${{ env.REPORTS_BRANCH }} \
|
||||
|| git checkout --orphan ${{ env.REPORTS_BRANCH }}
|
||||
git reset --hard
|
||||
git add ${{ env.REPORTS_FOLDER }}
|
||||
git commit -m "Benchmark report for ${{ matrix.agent-name }} @ $(date +'%Y-%m-%d')" \
|
||||
&& git push origin ${{ env.REPORTS_BRANCH }}
|
||||
11
.github/workflows/autogpts-ci.yml
vendored
11
.github/workflows/autogpts-ci.yml
vendored
@@ -37,13 +37,13 @@ jobs:
|
||||
min-python-version: '3.10'
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
@@ -56,13 +56,14 @@ jobs:
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
cd autogpts/${{ matrix.agent-name }}
|
||||
poetry run agbenchmark --mock
|
||||
poetry run agbenchmark --mock --test=BasicRetrieval --test=Battleship --test=WebArenaTask_0
|
||||
poetry run agbenchmark --test=WriteFile
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AGENT_NAME: ${{ matrix.agent-name }}
|
||||
HELICONE_API_KEY: ${{ secrets.HELICONE_API_KEY }}
|
||||
REQUESTS_CA_BUNDLE: /etc/ssl/certs/ca-certificates.crt
|
||||
HELICONE_CACHE_ENABLED: false
|
||||
HELICONE_PROPERTY_AGENT: ${{ matrix.agent-name }}
|
||||
REPORT_LOCATION: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
REPORTS_FOLDER: ${{ format('../../reports/{0}', matrix.agent-name) }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
|
||||
15
.github/workflows/benchmark-ci.yml
vendored
15
.github/workflows/benchmark-ci.yml
vendored
@@ -23,12 +23,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
@@ -78,13 +78,13 @@ jobs:
|
||||
timeout-minutes: 20
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
@@ -96,11 +96,10 @@ jobs:
|
||||
- name: Run regression tests
|
||||
run: |
|
||||
./run agent start ${{ matrix.agent-name }}
|
||||
sleep 10
|
||||
cd autogpts/${{ matrix.agent-name }}
|
||||
|
||||
set +e # Ignore non-zero exit codes and continue execution
|
||||
echo "Running the following command: poetry run agbenchmark --maintain --mock"
|
||||
|
||||
poetry run agbenchmark --maintain --mock
|
||||
EXIT_CODE=$?
|
||||
set -e # Stop ignoring non-zero exit codes
|
||||
@@ -127,7 +126,7 @@ jobs:
|
||||
|
||||
poetry run agbenchmark --mock
|
||||
poetry run pytest -vv -s tests
|
||||
|
||||
|
||||
CHANGED=$(git diff --name-only | grep -E '(agbenchmark/challenges)|(../frontend/assets)') || echo "No diffs"
|
||||
if [ ! -z "$CHANGED" ]; then
|
||||
echo "There are unstaged changes please run agbenchmark and commit those changes since they are needed."
|
||||
@@ -138,3 +137,5 @@ jobs:
|
||||
fi
|
||||
env:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
TELEMETRY_ENVIRONMENT: autogpt-benchmark-ci
|
||||
TELEMETRY_OPT_IN: ${{ github.ref_name == 'master' }}
|
||||
|
||||
@@ -10,13 +10,13 @@ jobs:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v2
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
|
||||
46
.github/workflows/build-frontend.yml
vendored
46
.github/workflows/build-frontend.yml
vendored
@@ -1,46 +0,0 @@
|
||||
name: Build and Commit Frontend
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- development
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'frontend/**'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v2
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v1
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
- name: Build Flutter Web
|
||||
run: |
|
||||
cd frontend
|
||||
flutter build web --base-href /app/
|
||||
- name: Set branch name
|
||||
id: vars
|
||||
run: echo "::set-output name=branch::frontend_build_${GITHUB_SHA}"
|
||||
- name: Commit and Push
|
||||
run: |
|
||||
git config --local user.email "action@github.com"
|
||||
git config --local user.name "GitHub Action"
|
||||
git add frontend/build/web
|
||||
git commit -m "Update frontend build" -a
|
||||
git checkout -b ${{ steps.vars.outputs.branch }}
|
||||
echo "Commit hash: ${GITHUB_SHA}"
|
||||
git push origin ${{ steps.vars.outputs.branch }}
|
||||
# - name: Create Pull Request
|
||||
# uses: peter-evans/create-pull-request@v3
|
||||
# with:
|
||||
# title: "Update frontend build"
|
||||
# body: "This PR updates the frontend build."
|
||||
# branch: ${{ steps.vars.outputs.branch }}
|
||||
# base: "master"
|
||||
2
.github/workflows/close-stale-issues.yml
vendored
2
.github/workflows/close-stale-issues.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
stale:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/stale@v8
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
# operations-per-run: 5000
|
||||
stale-issue-message: >
|
||||
|
||||
60
.github/workflows/frontend-ci.yml
vendored
Normal file
60
.github/workflows/frontend-ci.yml
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
name: Frontend CI/CD
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- development
|
||||
- 'ci-test*' # This will match any branch that starts with "ci-test"
|
||||
paths:
|
||||
- 'frontend/**'
|
||||
- '.github/workflows/frontend-ci.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'frontend/**'
|
||||
- '.github/workflows/frontend-ci.yml'
|
||||
|
||||
jobs:
|
||||
build:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
BUILD_BRANCH: ${{ format('frontend-build/{0}', github.ref_name) }}
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Flutter
|
||||
uses: subosito/flutter-action@v2
|
||||
with:
|
||||
flutter-version: '3.13.2'
|
||||
|
||||
- name: Build Flutter to Web
|
||||
run: |
|
||||
cd frontend
|
||||
flutter build web --base-href /app/
|
||||
|
||||
# - name: Commit and Push to ${{ env.BUILD_BRANCH }}
|
||||
# if: github.event_name == 'push'
|
||||
# run: |
|
||||
# git config --local user.email "action@github.com"
|
||||
# git config --local user.name "GitHub Action"
|
||||
# git add frontend/build/web
|
||||
# git checkout -B ${{ env.BUILD_BRANCH }}
|
||||
# git commit -m "Update frontend build to ${GITHUB_SHA:0:7}" -a
|
||||
# git push -f origin ${{ env.BUILD_BRANCH }}
|
||||
|
||||
- name: Create PR ${{ env.BUILD_BRANCH }} -> ${{ github.ref_name }}
|
||||
if: github.event_name == 'push'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
add-paths: frontend/build/web
|
||||
base: ${{ github.ref_name }}
|
||||
branch: ${{ env.BUILD_BRANCH }}
|
||||
delete-branch: true
|
||||
title: "Update frontend build in `${{ github.ref_name }}`"
|
||||
body: "This PR updates the frontend build based on commit ${{ github.sha }}."
|
||||
commit-message: "Update frontend build based on commit ${{ github.sha }}"
|
||||
6
.github/workflows/hackathon.yml
vendored
6
.github/workflows/hackathon.yml
vendored
@@ -88,13 +88,13 @@ jobs:
|
||||
run: docker ps
|
||||
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
submodules: true
|
||||
|
||||
- name: Set up Python ${{ env.min-python-version }}
|
||||
uses: actions/setup-python@v2
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ env.min-python-version }}
|
||||
|
||||
@@ -107,7 +107,7 @@ jobs:
|
||||
curl -sSL https://install.python-poetry.org | python -
|
||||
|
||||
- name: Install Node.js
|
||||
uses: actions/setup-node@v1
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: v18.15
|
||||
|
||||
|
||||
15
.github/workflows/pr-label.yml
vendored
15
.github/workflows/pr-label.yml
vendored
@@ -52,6 +52,15 @@ jobs:
|
||||
l_label: 'size/l'
|
||||
l_max_size: 500
|
||||
xl_label: 'size/xl'
|
||||
message_if_xl: >
|
||||
This PR exceeds the recommended size of 500 lines.
|
||||
Please make sure you are NOT addressing multiple issues with one PR.
|
||||
message_if_xl:
|
||||
|
||||
scope:
|
||||
if: ${{ github.event_name == 'pull_request_target' }}
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v5
|
||||
with:
|
||||
sync-labels: true
|
||||
|
||||
6
.pr_agent.toml
Normal file
6
.pr_agent.toml
Normal file
@@ -0,0 +1,6 @@
|
||||
[pr_reviewer]
|
||||
num_code_suggestions=0
|
||||
|
||||
[pr_code_suggestions]
|
||||
commitable_code_suggestions=false
|
||||
num_code_suggestions=0
|
||||
@@ -77,37 +77,47 @@ After executing the above commands, running `./run setup` should work successful
|
||||
|
||||
#### Store Project Files within the WSL File System
|
||||
If you continue to experience issues, consider storing your project files within the WSL file system instead of the Windows file system. This method avoids issues related to path translations and permissions and provides a more consistent development environment.
|
||||
|
||||
You can keep running the command to get feedback on where you are up to with your setup.
|
||||
When setup has been completed, the command will return an output like this:
|
||||
|
||||

|
||||
You can keep running the command to get feedback on where you are up to with your setup.
|
||||
When setup has been completed, the command will return an output like this:
|
||||
|
||||

|
||||
|
||||
## Creating Your Agent
|
||||
|
||||
Now setup has been completed its time to create your agent template.
|
||||
Do so by running the `./run agent create YOUR_AGENT_NAME` replacing YOUR_AGENT_NAME with a name of your choice. Examples of valid names: swiftyosgpt or SwiftyosAgent or swiftyos_agent
|
||||
After completing the setup, the next step is to create your agent template.
|
||||
Execute the command `./run agent create YOUR_AGENT_NAME`, where `YOUR_AGENT_NAME` should be replaced with a name of your choosing.
|
||||
|
||||

|
||||
Tips for naming your agent:
|
||||
* Give it its own unique name, or name it after yourself
|
||||
* Include an important aspect of your agent in the name, such as its purpose
|
||||
|
||||
Upon creating your agent its time to officially enter the Arena!
|
||||
Do so by running `./run arena enter YOUR_AGENT_NAME`
|
||||
Examples: `SwiftyosAssistant`, `PwutsPRAgent`, `Narvis`, `evo.ninja`
|
||||
|
||||

|
||||

|
||||
|
||||
> Note: for advanced users, create a new branch and create a file called YOUR_AGENT_NAME.json in the arena directory. Then commit this and create a PR to merge into the main repo. Only single file entries will be permitted. The json file needs the following format.
|
||||
```json
|
||||
{
|
||||
"github_repo_url": "https://github.com/Swiftyos/YourAgentName",
|
||||
"timestamp": "2023-09-18T10:03:38.051498",
|
||||
"commit_hash_to_benchmark": "ac36f7bfc7f23ad8800339fa55943c1405d80d5e",
|
||||
"branch_to_benchmark": "master"
|
||||
}
|
||||
```
|
||||
- github_repo_url: the url to your fork
|
||||
- timestamp: timestamp of the last update of this file
|
||||
- commit_hash_to_benchmark: the commit hash of your entry. You update each time you have an something ready to be officially entered into the hackathon
|
||||
- branch_to_benchmark: the branch you are using to develop your agent on, default is master.
|
||||
### Optional: Entering the Arena
|
||||
|
||||
Entering the Arena is an optional step intended for those who wish to actively participate in the agent leaderboard. If you decide to participate, you can enter the Arena by running `./run arena enter YOUR_AGENT_NAME`. This step is not mandatory for the development or testing of your agent.
|
||||
|
||||
Entries with names like `agent`, `ExampleAgent`, `test_agent` or `MyExampleGPT` will NOT be merged. We also don't accept copycat entries that use the name of other projects, like `AutoGPT` or `evo.ninja`.
|
||||
|
||||

|
||||
|
||||
> **Note**
|
||||
> For advanced users, create a new branch and create a file called YOUR_AGENT_NAME.json in the arena directory. Then commit this and create a PR to merge into the main repo. Only single file entries will be permitted. The json file needs the following format:
|
||||
> ```json
|
||||
> {
|
||||
> "github_repo_url": "https://github.com/Swiftyos/YourAgentName",
|
||||
> "timestamp": "2023-09-18T10:03:38.051498",
|
||||
> "commit_hash_to_benchmark": "ac36f7bfc7f23ad8800339fa55943c1405d80d5e",
|
||||
> "branch_to_benchmark": "master"
|
||||
> }
|
||||
> ```
|
||||
> - `github_repo_url`: the url to your fork
|
||||
> - `timestamp`: timestamp of the last update of this file
|
||||
> - `commit_hash_to_benchmark`: the commit hash of your entry. You update each time you have an something ready to be officially entered into the hackathon
|
||||
> - `branch_to_benchmark`: the branch you are using to develop your agent on, default is master.
|
||||
|
||||
|
||||
## Running your Agent
|
||||
|
||||
10
README.md
10
README.md
@@ -102,7 +102,11 @@ To maintain a uniform standard and ensure seamless compatibility with many curre
|
||||
---
|
||||
|
||||
<p align="center">
|
||||
<a href="https://star-history.com/#Significant-Gravitas/AutoGPT&Date">
|
||||
<img src="https://api.star-history.com/svg?repos=Significant-Gravitas/AutoGPT&type=Date" alt="Star History Chart">
|
||||
</a>
|
||||
<a href="https://star-history.com/#Significant-Gravitas/AutoGPT">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=Significant-Gravitas/AutoGPT&type=Date&theme=dark" />
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=Significant-Gravitas/AutoGPT&type=Date" />
|
||||
<img alt="Star History Chart" src="https://api.star-history.com/svg?repos=Significant-Gravitas/AutoGPT&type=Date" />
|
||||
</picture>
|
||||
</a>
|
||||
</p>
|
||||
|
||||
66
SECURITY.md
Normal file
66
SECURITY.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Security Policy
|
||||
|
||||
- [**Using AutoGPT Securely**](#using-AutoGPT-securely)
|
||||
- [Restrict Workspace](#restrict-workspace)
|
||||
- [Untrusted inputs](#untrusted-inputs)
|
||||
- [Data privacy](#data-privacy)
|
||||
- [Untrusted environments or networks](#untrusted-environments-or-networks)
|
||||
- [Multi-Tenant environments](#multi-tenant-environments)
|
||||
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
|
||||
|
||||
## Using AutoGPT Securely
|
||||
|
||||
### Restrict Workspace
|
||||
|
||||
Since agents can read and write files, it is important to keep them restricted to a specific workspace. This happens by default *unless* RESTRICT_TO_WORKSPACE is set to False.
|
||||
|
||||
Disabling RESTRICT_TO_WORKSPACE can increase security risks. However, if you still need to disable it, consider running AutoGPT inside a [sandbox](https://developers.google.com/code-sandboxing), to mitigate some of these risks.
|
||||
|
||||
### Untrusted inputs
|
||||
|
||||
When handling untrusted inputs, it's crucial to isolate the execution and carefully pre-process inputs to mitigate script injection risks.
|
||||
|
||||
For maximum security when handling untrusted inputs, you may need to employ the following:
|
||||
|
||||
* Sandboxing: Isolate the process.
|
||||
* Updates: Keep your libraries (including AutoGPT) updated with the latest security patches.
|
||||
* Input Sanitation: Before feeding data to the model, sanitize inputs rigorously. This involves techniques such as:
|
||||
* Validation: Enforce strict rules on allowed characters and data types.
|
||||
* Filtering: Remove potentially malicious scripts or code fragments.
|
||||
* Encoding: Convert special characters into safe representations.
|
||||
* Verification: Run tooling that identifies potential script injections (e.g. [models that detect prompt injection attempts](https://python.langchain.com/docs/guides/safety/hugging_face_prompt_injection)).
|
||||
|
||||
### Data privacy
|
||||
|
||||
To protect sensitive data from potential leaks or unauthorized access, it is crucial to sandbox the agent execution. This means running it in a secure, isolated environment, which helps mitigate many attack vectors.
|
||||
|
||||
### Untrusted environments or networks
|
||||
|
||||
Since AutoGPT performs network calls to the OpenAI API, it is important to always run it with trusted environments and networks. Running it on untrusted environments can expose your API KEY to attackers.
|
||||
Additionally, running it on an untrusted network can expose your data to potential network attacks.
|
||||
|
||||
However, even when running on trusted networks, it is important to always encrypt sensitive data while sending it over the network.
|
||||
|
||||
### Multi-Tenant environments
|
||||
|
||||
If you intend to run multiple AutoGPT brains in parallel, it is your responsibility to ensure the models do not interact or access each other's data.
|
||||
|
||||
The primary areas of concern are tenant isolation, resource allocation, model sharing and hardware attacks.
|
||||
|
||||
- Tenant Isolation: you must make sure that the tenants run separately to prevent unwanted access to the data from other tenants. Keeping model network traffic separate is also important because you not only prevent unauthorized access to data, but also prevent malicious users or tenants sending prompts to execute under another tenant’s identity.
|
||||
|
||||
- Resource Allocation: a denial of service caused by one tenant can affect the overall system health. Implement safeguards like rate limits, access controls, and health monitoring.
|
||||
|
||||
- Data Sharing: in a multi-tenant design with data sharing, ensure tenants and users understand the security risks and sandbox agent execution to mitigate risks.
|
||||
|
||||
- Hardware Attacks: the hardware (GPUs or TPUs) can also be attacked. [Research](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other brains or processes running on the same system at the same time.
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
Beware that none of the topics under [Using AutoGPT Securely](#using-AutoGPT-securely) are considered vulnerabilities on AutoGPT.
|
||||
|
||||
However, If you have discovered a security vulnerability in this project, please report it privately. **Do not disclose it as a public issue.** This gives us time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released.
|
||||
|
||||
Please disclose it as a private [security advisory](https://github.com/Significant-Gravitas/AutoGPT/security/advisories/new).
|
||||
|
||||
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
|
||||
@@ -2,8 +2,15 @@
|
||||
### AutoGPT - GENERAL SETTINGS
|
||||
################################################################################
|
||||
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: my-openai-api-key)
|
||||
OPENAI_API_KEY=your-openai-api-key
|
||||
## OPENAI_API_KEY - OpenAI API Key (Example: sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# OPENAI_API_KEY=
|
||||
|
||||
## ANTHROPIC_API_KEY - Anthropic API Key (Example: sk-ant-api03-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx)
|
||||
# ANTHROPIC_API_KEY=
|
||||
|
||||
## TELEMETRY_OPT_IN - Share telemetry on errors and other issues with the AutoGPT team, e.g. through Sentry.
|
||||
## This helps us to spot and solve problems earlier & faster. (Default: DISABLED)
|
||||
# TELEMETRY_OPT_IN=true
|
||||
|
||||
## EXECUTE_LOCAL_COMMANDS - Allow local command execution (Default: False)
|
||||
# EXECUTE_LOCAL_COMMANDS=False
|
||||
@@ -13,15 +20,15 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
## RESTRICT_TO_WORKSPACE - Restrict file operations to workspace ./data/agents/<agent_id>/workspace (Default: True)
|
||||
# RESTRICT_TO_WORKSPACE=True
|
||||
|
||||
## DISABLED_COMMAND_CATEGORIES - The list of categories of commands that are disabled (Default: None)
|
||||
# DISABLED_COMMAND_CATEGORIES=
|
||||
## DISABLED_COMMANDS - The comma separated list of commands that are disabled (Default: None)
|
||||
# DISABLED_COMMANDS=
|
||||
|
||||
## WORKSPACE_BACKEND - Choose a storage backend for workspace contents
|
||||
## FILE_STORAGE_BACKEND - Choose a storage backend for contents
|
||||
## Options: local, gcs, s3
|
||||
# WORKSPACE_BACKEND=local
|
||||
# FILE_STORAGE_BACKEND=local
|
||||
|
||||
## WORKSPACE_STORAGE_BUCKET - GCS/S3 Bucket to store workspace contents in
|
||||
# WORKSPACE_STORAGE_BUCKET=autogpt
|
||||
## STORAGE_BUCKET - GCS/S3 Bucket to store contents in
|
||||
# STORAGE_BUCKET=autogpt
|
||||
|
||||
## GCS Credentials
|
||||
# see https://cloud.google.com/storage/docs/authentication#libauth
|
||||
@@ -40,9 +47,6 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
## AI_SETTINGS_FILE - Specifies which AI Settings file to use, relative to the AutoGPT root directory. (defaults to ai_settings.yaml)
|
||||
# AI_SETTINGS_FILE=ai_settings.yaml
|
||||
|
||||
## PLUGINS_CONFIG_FILE - The path to the plugins_config.yaml file, relative to the AutoGPT root directory. (Default plugins_config.yaml)
|
||||
# PLUGINS_CONFIG_FILE=plugins_config.yaml
|
||||
|
||||
## PROMPT_SETTINGS_FILE - Specifies which Prompt Settings file to use, relative to the AutoGPT root directory. (defaults to prompt_settings.yaml)
|
||||
# PROMPT_SETTINGS_FILE=prompt_settings.yaml
|
||||
|
||||
@@ -86,14 +90,14 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
### LLM MODELS
|
||||
################################################################################
|
||||
|
||||
## SMART_LLM - Smart language model (Default: gpt-4-0314)
|
||||
# SMART_LLM=gpt-4-0314
|
||||
## SMART_LLM - Smart language model (Default: gpt-4-turbo)
|
||||
# SMART_LLM=gpt-4-turbo
|
||||
|
||||
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo-16k)
|
||||
# FAST_LLM=gpt-3.5-turbo-16k
|
||||
## FAST_LLM - Fast language model (Default: gpt-3.5-turbo)
|
||||
# FAST_LLM=gpt-3.5-turbo
|
||||
|
||||
## EMBEDDING_MODEL - Model to use for creating embeddings
|
||||
# EMBEDDING_MODEL=text-embedding-ada-002
|
||||
# EMBEDDING_MODEL=text-embedding-3-small
|
||||
|
||||
################################################################################
|
||||
### SHELL EXECUTION
|
||||
@@ -228,6 +232,8 @@ OPENAI_API_KEY=your-openai-api-key
|
||||
### Agent Protocol Server Settings
|
||||
################################################################################
|
||||
## AP_SERVER_PORT - Specifies what port the agent protocol server will listen on. (Default: 8000)
|
||||
## AP_SERVER_DB_URL - Specifies what connection url the agent protocol database will connect to (Default: Internal SQLite)
|
||||
## AP_SERVER_CORS_ALLOWED_ORIGINS - Comma separated list of allowed origins for CORS. (Default: http://localhost:{AP_SERVER_PORT})
|
||||
# AP_SERVER_PORT=8000
|
||||
# # AP_SERVER_DB_URL - Specifies what connection url the agent protocol database will connect to (Default: Internal SQLite)
|
||||
# AP_SERVER_DB_URL=sqlite:///data/ap_server.db
|
||||
# AP_SERVER_CORS_ALLOWED_ORIGINS=
|
||||
|
||||
@@ -22,6 +22,11 @@ repos:
|
||||
- id: black
|
||||
language_version: python3.10
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 7.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
|
||||
# - repo: https://github.com/pre-commit/mirrors-mypy
|
||||
# rev: 'v1.3.0'
|
||||
# hooks:
|
||||
|
||||
@@ -3,12 +3,12 @@ import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from autogpt.agent_manager.agent_manager import AgentManager
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.app.main import _configure_openai_provider, run_interaction_loop
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.app.main import _configure_llm_provider, run_interaction_loop
|
||||
from autogpt.config import AIProfile, ConfigBuilder
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_logging
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
LOG_DIR = Path(__file__).parent / "logs"
|
||||
|
||||
@@ -19,29 +19,27 @@ def run_specific_agent(task: str, continuous_mode: bool = False) -> None:
|
||||
|
||||
|
||||
def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
config.logging.level = logging.DEBUG
|
||||
config.logging.log_dir = LOG_DIR
|
||||
config.logging.plain_console_output = True
|
||||
configure_logging(**config.logging.dict())
|
||||
configure_logging(
|
||||
level=logging.DEBUG,
|
||||
log_dir=LOG_DIR,
|
||||
plain_console_output=True,
|
||||
)
|
||||
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
config.continuous_mode = continuous_mode
|
||||
config.continuous_limit = 20
|
||||
config.noninteractive_mode = True
|
||||
config.memory_backend = "no_memory"
|
||||
|
||||
command_registry = CommandRegistry.with_command_modules(COMMAND_CATEGORIES, config)
|
||||
|
||||
ai_profile = AIProfile(
|
||||
ai_name="AutoGPT",
|
||||
ai_role="a multi-purpose AI assistant.",
|
||||
ai_goals=[task],
|
||||
)
|
||||
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = config.openai_functions
|
||||
agent_settings = AgentSettings(
|
||||
name=Agent.default_settings.name,
|
||||
agent_id=AgentManager.generate_id("AutoGPT-benchmark"),
|
||||
description=Agent.default_settings.description,
|
||||
ai_profile=ai_profile,
|
||||
config=AgentConfiguration(
|
||||
@@ -49,19 +47,23 @@ def bootstrap_agent(task: str, continuous_mode: bool) -> Agent:
|
||||
smart_llm=config.smart_llm,
|
||||
allow_fs_access=not config.restrict_to_workspace,
|
||||
use_functions_api=config.openai_functions,
|
||||
plugins=config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
local = config.file_storage_backend == FileStorageBackendName.LOCAL
|
||||
restrict_to_root = not local or config.restrict_to_workspace
|
||||
file_storage = get_storage(
|
||||
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
|
||||
)
|
||||
file_storage.initialize()
|
||||
|
||||
agent = Agent(
|
||||
settings=agent_settings,
|
||||
llm_provider=_configure_openai_provider(config),
|
||||
command_registry=command_registry,
|
||||
llm_provider=_configure_llm_provider(config),
|
||||
file_storage=file_storage,
|
||||
legacy_config=config,
|
||||
)
|
||||
agent.attach_fs(config.app_data_dir / "agents" / "AutoGPT-benchmark") # HACK
|
||||
return agent
|
||||
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from typing import Optional
|
||||
|
||||
from autogpt.agent_manager import AgentManager
|
||||
from autogpt.agents.agent import Agent, AgentConfiguration, AgentSettings
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.config import AIDirectives, AIProfile, Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.logs.config import configure_chat_plugins
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.plugins import scan_plugins
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
|
||||
|
||||
def create_agent(
|
||||
agent_id: str,
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
app_config: Config,
|
||||
file_storage: FileStorage,
|
||||
llm_provider: ChatModelProvider,
|
||||
directives: Optional[AIDirectives] = None,
|
||||
) -> Agent:
|
||||
@@ -23,26 +21,28 @@ def create_agent(
|
||||
directives = AIDirectives.from_file(app_config.prompt_settings_file)
|
||||
|
||||
agent = _configure_agent(
|
||||
agent_id=agent_id,
|
||||
task=task,
|
||||
ai_profile=ai_profile,
|
||||
directives=directives,
|
||||
app_config=app_config,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
agent.state.agent_id = AgentManager.generate_id(agent.ai_profile.ai_name)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def configure_agent_with_state(
|
||||
state: AgentSettings,
|
||||
app_config: Config,
|
||||
file_storage: FileStorage,
|
||||
llm_provider: ChatModelProvider,
|
||||
) -> Agent:
|
||||
return _configure_agent(
|
||||
state=state,
|
||||
app_config=app_config,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
@@ -50,26 +50,21 @@ def configure_agent_with_state(
|
||||
def _configure_agent(
|
||||
app_config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
file_storage: FileStorage,
|
||||
agent_id: str = "",
|
||||
task: str = "",
|
||||
ai_profile: Optional[AIProfile] = None,
|
||||
directives: Optional[AIDirectives] = None,
|
||||
state: Optional[AgentSettings] = None,
|
||||
) -> Agent:
|
||||
if not (state or task and ai_profile and directives):
|
||||
if not (state or agent_id and task and ai_profile and directives):
|
||||
raise TypeError(
|
||||
"Either (state) or (task, ai_profile, directives) must be specified"
|
||||
"Either (state) or (agent_id, task, ai_profile, directives)"
|
||||
" must be specified"
|
||||
)
|
||||
|
||||
app_config.plugins = scan_plugins(app_config)
|
||||
configure_chat_plugins(app_config)
|
||||
|
||||
# Create a CommandRegistry instance and scan default folder
|
||||
command_registry = CommandRegistry.with_command_modules(
|
||||
modules=COMMAND_CATEGORIES,
|
||||
config=app_config,
|
||||
)
|
||||
|
||||
agent_state = state or create_agent_state(
|
||||
agent_id=agent_id,
|
||||
task=task,
|
||||
ai_profile=ai_profile,
|
||||
directives=directives,
|
||||
@@ -81,21 +76,20 @@ def _configure_agent(
|
||||
return Agent(
|
||||
settings=agent_state,
|
||||
llm_provider=llm_provider,
|
||||
command_registry=command_registry,
|
||||
file_storage=file_storage,
|
||||
legacy_config=app_config,
|
||||
)
|
||||
|
||||
|
||||
def create_agent_state(
|
||||
agent_id: str,
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
directives: AIDirectives,
|
||||
app_config: Config,
|
||||
) -> AgentSettings:
|
||||
agent_prompt_config = Agent.default_settings.prompt_config.copy(deep=True)
|
||||
agent_prompt_config.use_functions_api = app_config.openai_functions
|
||||
|
||||
return AgentSettings(
|
||||
agent_id=agent_id,
|
||||
name=Agent.default_settings.name,
|
||||
description=Agent.default_settings.description,
|
||||
task=task,
|
||||
@@ -106,8 +100,6 @@ def create_agent_state(
|
||||
smart_llm=app_config.smart_llm,
|
||||
allow_fs_access=not app_config.restrict_to_workspace,
|
||||
use_functions_api=app_config.openai_functions,
|
||||
plugins=app_config.plugins,
|
||||
),
|
||||
prompt_config=agent_prompt_config,
|
||||
history=Agent.default_settings.history.copy(deep=True),
|
||||
)
|
||||
|
||||
@@ -1,21 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt.config.ai_directives import AIDirectives
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
|
||||
from .configurators import _configure_agent
|
||||
from .profile_generator import generate_agent_profile_for_task
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import ChatModelProvider
|
||||
|
||||
from autogpt.config.ai_directives import AIDirectives
|
||||
|
||||
from .configurators import _configure_agent
|
||||
from .profile_generator import generate_agent_profile_for_task
|
||||
|
||||
|
||||
async def generate_agent_for_task(
|
||||
agent_id: str,
|
||||
task: str,
|
||||
app_config: "Config",
|
||||
llm_provider: "ChatModelProvider",
|
||||
) -> "Agent":
|
||||
app_config: Config,
|
||||
file_storage: FileStorage,
|
||||
llm_provider: ChatModelProvider,
|
||||
) -> Agent:
|
||||
base_directives = AIDirectives.from_file(app_config.prompt_settings_file)
|
||||
ai_profile, task_directives = await generate_agent_profile_for_task(
|
||||
task=task,
|
||||
@@ -23,9 +28,11 @@ async def generate_agent_for_task(
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
return _configure_agent(
|
||||
agent_id=agent_id,
|
||||
task=task,
|
||||
ai_profile=ai_profile,
|
||||
directives=base_directives + task_directives,
|
||||
app_config=app_config,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
|
||||
@@ -8,7 +8,6 @@ from autogpt.core.prompting import (
|
||||
LanguageModelClassification,
|
||||
PromptStrategy,
|
||||
)
|
||||
from autogpt.core.prompting.utils import json_loads
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
@@ -203,9 +202,7 @@ class AgentProfileGenerator(PromptStrategy):
|
||||
f"LLM did not call {self._create_agent_function.name} function; "
|
||||
"agent profile creation failed"
|
||||
)
|
||||
arguments: object = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
arguments: object = response_content.tool_calls[0].function.arguments
|
||||
ai_profile = AIProfile(
|
||||
ai_name=arguments.get("name"),
|
||||
ai_role=arguments.get("description"),
|
||||
@@ -238,18 +235,14 @@ async def generate_agent_profile_for_task(
|
||||
prompt = agent_profile_generator.build_prompt(task)
|
||||
|
||||
# Call LLM with the string as user input
|
||||
output = (
|
||||
await llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=app_config.smart_llm,
|
||||
functions=prompt.functions,
|
||||
)
|
||||
).response
|
||||
output = await llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=app_config.smart_llm,
|
||||
functions=prompt.functions,
|
||||
completion_parser=agent_profile_generator.parse_response_content,
|
||||
)
|
||||
|
||||
# Debug LLM Output
|
||||
logger.debug(f"AI Config Generator Raw Output: {output}")
|
||||
logger.debug(f"AI Config Generator Raw Output: {output.response}")
|
||||
|
||||
# Parse the output
|
||||
ai_profile, ai_directives = agent_profile_generator.parse_response_content(output)
|
||||
|
||||
return ai_profile, ai_directives
|
||||
return output.parsed_result
|
||||
|
||||
@@ -2,47 +2,44 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import AgentSettings
|
||||
|
||||
from autogpt.agents.utils.agent_file_manager import AgentFileManager
|
||||
from autogpt.agents.agent import AgentSettings
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
|
||||
|
||||
class AgentManager:
|
||||
def __init__(self, app_data_dir: Path):
|
||||
self.agents_dir = app_data_dir / "agents"
|
||||
if not self.agents_dir.exists():
|
||||
self.agents_dir.mkdir()
|
||||
def __init__(self, file_storage: FileStorage):
|
||||
self.file_manager = file_storage.clone_with_subroot("agents")
|
||||
|
||||
@staticmethod
|
||||
def generate_id(agent_name: str) -> str:
|
||||
"""Generate a unique ID for an agent given agent name."""
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"{agent_name}-{unique_id}"
|
||||
|
||||
def list_agents(self) -> list[str]:
|
||||
return [
|
||||
dir.name
|
||||
for dir in self.agents_dir.iterdir()
|
||||
if dir.is_dir() and AgentFileManager(dir).state_file_path.exists()
|
||||
]
|
||||
"""Return all agent directories within storage."""
|
||||
agent_dirs: list[str] = []
|
||||
for file_path in self.file_manager.list_files():
|
||||
if len(file_path.parts) == 2 and file_path.name == "state.json":
|
||||
agent_dirs.append(file_path.parent.name)
|
||||
return agent_dirs
|
||||
|
||||
def get_agent_dir(self, agent_id: str, must_exist: bool = False) -> Path:
|
||||
def get_agent_dir(self, agent_id: str) -> Path:
|
||||
"""Return the directory of the agent with the given ID."""
|
||||
assert len(agent_id) > 0
|
||||
agent_dir = self.agents_dir / agent_id
|
||||
if must_exist and not agent_dir.exists():
|
||||
agent_dir: Path | None = None
|
||||
if self.file_manager.exists(agent_id):
|
||||
agent_dir = self.file_manager.root / agent_id
|
||||
else:
|
||||
raise FileNotFoundError(f"No agent with ID '{agent_id}'")
|
||||
return agent_dir
|
||||
|
||||
def retrieve_state(self, agent_id: str) -> AgentSettings:
|
||||
from autogpt.agents.agent import AgentSettings
|
||||
|
||||
agent_dir = self.get_agent_dir(agent_id, True)
|
||||
state_file = AgentFileManager(agent_dir).state_file_path
|
||||
if not state_file.exists():
|
||||
def load_agent_state(self, agent_id: str) -> AgentSettings:
|
||||
"""Load the state of the agent with the given ID."""
|
||||
state_file_path = Path(agent_id) / "state.json"
|
||||
if not self.file_manager.exists(state_file_path):
|
||||
raise FileNotFoundError(f"Agent with ID '{agent_id}' has no state.json")
|
||||
|
||||
state = AgentSettings.load_from_json_file(state_file)
|
||||
state.agent_data_dir = agent_dir
|
||||
return state
|
||||
state = self.file_manager.read_file(state_file_path)
|
||||
return AgentSettings.parse_raw(state)
|
||||
|
||||
37
autogpts/autogpt/autogpt/agents/README.md
Normal file
37
autogpts/autogpt/autogpt/agents/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# 🤖 Agents
|
||||
|
||||
Agent is composed of [🧩 Components](./components.md) and responsible for executing pipelines and some additional logic. The base class for all agents is `BaseAgent`, it has the necessary logic to collect components and execute protocols.
|
||||
|
||||
## Important methods
|
||||
|
||||
`BaseAgent` provides two abstract methods needed for any agent to work properly:
|
||||
1. `propose_action`: This method is responsible for proposing an action based on the current state of the agent, it returns `ThoughtProcessOutput`.
|
||||
2. `execute`: This method is responsible for executing the proposed action, returns `ActionResult`.
|
||||
|
||||
## AutoGPT Agent
|
||||
|
||||
`Agent` is the main agent provided by AutoGPT. It's a subclass of `BaseAgent`. It has all the [Built-in Components](./built-in-components.md). `Agent` implements the essential abstract methods from `BaseAgent`: `propose_action` and `execute`.
|
||||
|
||||
## Building your own Agent
|
||||
|
||||
The easiest way to build your own agent is to extend the `Agent` class and add additional components. By doing this you can reuse the existing components and the default logic for executing [⚙️ Protocols](./protocols.md).
|
||||
|
||||
```py
|
||||
class MyComponent(AgentComponent):
|
||||
pass
|
||||
|
||||
class MyAgent(Agent):
|
||||
def __init__(
|
||||
self,
|
||||
settings: AgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
file_storage: FileStorage,
|
||||
legacy_config: Config,
|
||||
):
|
||||
# Call the parent constructor to bring in the default components
|
||||
super().__init__(settings, llm_provider, file_storage, legacy_config)
|
||||
# Add your custom component
|
||||
self.my_component = MyComponent()
|
||||
```
|
||||
|
||||
For more customization, you can override the `propose_action` and `execute` or even subclass `BaseAgent` directly. This way you can have full control over the agent's components and behavior. Have a look at the [implementation of Agent](https://github.com/Significant-Gravitas/AutoGPT/tree/master/autogpts/autogpt/autogpt/agents/agent.py) for more details.
|
||||
@@ -1,4 +1,9 @@
|
||||
from .agent import Agent
|
||||
from .base import AgentThoughts, BaseAgent, CommandArgs, CommandName
|
||||
from .agent import Agent, OneShotAgentActionProposal
|
||||
from .base import BaseAgent, BaseAgentActionProposal
|
||||
|
||||
__all__ = ["BaseAgent", "Agent", "CommandName", "CommandArgs", "AgentThoughts"]
|
||||
__all__ = [
|
||||
"BaseAgent",
|
||||
"Agent",
|
||||
"BaseAgentActionProposal",
|
||||
"OneShotAgentActionProposal",
|
||||
]
|
||||
|
||||
@@ -2,56 +2,71 @@ from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
|
||||
import sentry_sdk
|
||||
from pydantic import Field
|
||||
|
||||
from autogpt.commands.execute_code import CodeExecutorComponent
|
||||
from autogpt.commands.git_operations import GitOperationsComponent
|
||||
from autogpt.commands.image_gen import ImageGeneratorComponent
|
||||
from autogpt.commands.system import SystemComponent
|
||||
from autogpt.commands.user_interaction import UserInteractionComponent
|
||||
from autogpt.commands.web_search import WebSearchComponent
|
||||
from autogpt.commands.web_selenium import WebSeleniumComponent
|
||||
from autogpt.components.event_history import EventHistoryComponent
|
||||
from autogpt.core.configuration import Configurable
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
ChatMessage,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
)
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.llm.providers.openai import function_specs_from_commands
|
||||
from autogpt.logs.log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
USER_INPUT_FILE_NAME,
|
||||
LogCycleHandler,
|
||||
)
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
from autogpt.models.action_history import (
|
||||
Action,
|
||||
ActionErrorResult,
|
||||
ActionInterruptedByHuman,
|
||||
ActionResult,
|
||||
ActionSuccessResult,
|
||||
EpisodicActionHistory,
|
||||
)
|
||||
from autogpt.models.command import CommandOutput
|
||||
from autogpt.models.context_item import ContextItem
|
||||
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .features.context import ContextMixin
|
||||
from .features.file_workspace import FileWorkspaceMixin
|
||||
from .features.watchdog import WatchdogMixin
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentPromptConfiguration,
|
||||
OneShotAgentPromptStrategy,
|
||||
)
|
||||
from .utils.exceptions import (
|
||||
from autogpt.models.command import Command, CommandOutput
|
||||
from autogpt.utils.exceptions import (
|
||||
AgentException,
|
||||
AgentTerminated,
|
||||
CommandExecutionError,
|
||||
UnknownCommandError,
|
||||
)
|
||||
|
||||
from .base import BaseAgent, BaseAgentConfiguration, BaseAgentSettings
|
||||
from .features.agent_file_manager import FileManagerComponent
|
||||
from .features.context import ContextComponent
|
||||
from .features.watchdog import WatchdogComponent
|
||||
from .prompt_strategies.one_shot import (
|
||||
OneShotAgentActionProposal,
|
||||
OneShotAgentPromptStrategy,
|
||||
)
|
||||
from .protocols import (
|
||||
AfterExecute,
|
||||
AfterParse,
|
||||
CommandProvider,
|
||||
DirectiveProvider,
|
||||
MessageProvider,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -61,47 +76,64 @@ class AgentConfiguration(BaseAgentConfiguration):
|
||||
|
||||
class AgentSettings(BaseAgentSettings):
|
||||
config: AgentConfiguration = Field(default_factory=AgentConfiguration)
|
||||
prompt_config: OneShotAgentPromptConfiguration = Field(
|
||||
default_factory=(
|
||||
lambda: OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
)
|
||||
|
||||
history: EpisodicActionHistory[OneShotAgentActionProposal] = Field(
|
||||
default_factory=EpisodicActionHistory[OneShotAgentActionProposal]
|
||||
)
|
||||
"""(STATE) The action history of the agent."""
|
||||
|
||||
|
||||
class Agent(
|
||||
ContextMixin,
|
||||
FileWorkspaceMixin,
|
||||
WatchdogMixin,
|
||||
BaseAgent,
|
||||
Configurable[AgentSettings],
|
||||
):
|
||||
"""AutoGPT's primary Agent; uses one-shot prompting."""
|
||||
|
||||
class Agent(BaseAgent, Configurable[AgentSettings]):
|
||||
default_settings: AgentSettings = AgentSettings(
|
||||
name="Agent",
|
||||
description=__doc__,
|
||||
description=__doc__ if __doc__ else "",
|
||||
)
|
||||
|
||||
prompt_strategy: OneShotAgentPromptStrategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: AgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
command_registry: CommandRegistry,
|
||||
file_storage: FileStorage,
|
||||
legacy_config: Config,
|
||||
):
|
||||
prompt_strategy = OneShotAgentPromptStrategy(
|
||||
configuration=settings.prompt_config,
|
||||
logger=logger,
|
||||
super().__init__(settings)
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
self.ai_profile = settings.ai_profile
|
||||
self.directives = settings.directives
|
||||
prompt_config = OneShotAgentPromptStrategy.default_configuration.copy(deep=True)
|
||||
prompt_config.use_functions_api = (
|
||||
settings.config.use_functions_api
|
||||
# Anthropic currently doesn't support tools + prefilling :(
|
||||
and self.llm.provider_name != "anthropic"
|
||||
)
|
||||
super().__init__(
|
||||
settings=settings,
|
||||
llm_provider=llm_provider,
|
||||
prompt_strategy=prompt_strategy,
|
||||
command_registry=command_registry,
|
||||
legacy_config=legacy_config,
|
||||
self.prompt_strategy = OneShotAgentPromptStrategy(prompt_config, logger)
|
||||
self.commands: list[Command] = []
|
||||
|
||||
# Components
|
||||
self.system = SystemComponent(legacy_config, settings.ai_profile)
|
||||
self.history = EventHistoryComponent(
|
||||
settings.history,
|
||||
self.send_token_limit,
|
||||
lambda x: self.llm_provider.count_tokens(x, self.llm.name),
|
||||
legacy_config,
|
||||
llm_provider,
|
||||
)
|
||||
self.user_interaction = UserInteractionComponent(legacy_config)
|
||||
self.file_manager = FileManagerComponent(settings, file_storage)
|
||||
self.code_executor = CodeExecutorComponent(
|
||||
self.file_manager.workspace,
|
||||
settings,
|
||||
legacy_config,
|
||||
)
|
||||
self.git_ops = GitOperationsComponent(legacy_config)
|
||||
self.image_gen = ImageGeneratorComponent(
|
||||
self.file_manager.workspace, legacy_config
|
||||
)
|
||||
self.web_search = WebSearchComponent(legacy_config)
|
||||
self.web_selenium = WebSeleniumComponent(legacy_config, llm_provider, self.llm)
|
||||
self.context = ContextComponent(self.file_manager.workspace)
|
||||
self.watchdog = WatchdogComponent(settings.config, settings.history)
|
||||
|
||||
self.created_at = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
"""Timestamp the agent was created; only used for structured debug logging."""
|
||||
@@ -109,199 +141,153 @@ class Agent(
|
||||
self.log_cycle_handler = LogCycleHandler()
|
||||
"""LogCycleHandler for structured debug logging."""
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
*args,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
include_os_info: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> ChatPrompt:
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
self.event_history = settings.history
|
||||
self.legacy_config = legacy_config
|
||||
|
||||
# Clock
|
||||
extra_messages.append(
|
||||
ChatMessage.system(f"The current time and date is {time.strftime('%c')}"),
|
||||
async def propose_action(self) -> OneShotAgentActionProposal:
|
||||
"""Proposes the next action to execute, based on the task and current state.
|
||||
|
||||
Returns:
|
||||
The command name and arguments, if any, and the agent's thoughts.
|
||||
"""
|
||||
self.reset_trace()
|
||||
|
||||
# Get directives
|
||||
resources = await self.run_pipeline(DirectiveProvider.get_resources)
|
||||
constraints = await self.run_pipeline(DirectiveProvider.get_constraints)
|
||||
best_practices = await self.run_pipeline(DirectiveProvider.get_best_practices)
|
||||
|
||||
directives = self.state.directives.copy(deep=True)
|
||||
directives.resources += resources
|
||||
directives.constraints += constraints
|
||||
directives.best_practices += best_practices
|
||||
|
||||
# Get commands
|
||||
self.commands = await self.run_pipeline(CommandProvider.get_commands)
|
||||
self._remove_disabled_commands()
|
||||
|
||||
# Get messages
|
||||
messages = await self.run_pipeline(MessageProvider.get_messages)
|
||||
|
||||
prompt: ChatPrompt = self.prompt_strategy.build_prompt(
|
||||
messages=messages,
|
||||
task=self.state.task,
|
||||
ai_profile=self.state.ai_profile,
|
||||
ai_directives=directives,
|
||||
commands=function_specs_from_commands(self.commands),
|
||||
include_os_info=self.legacy_config.execute_local_commands,
|
||||
)
|
||||
|
||||
# Add budget information (if any) to prompt
|
||||
api_manager = ApiManager()
|
||||
if api_manager.get_total_budget() > 0.0:
|
||||
remaining_budget = (
|
||||
api_manager.get_total_budget() - api_manager.get_total_cost()
|
||||
)
|
||||
if remaining_budget < 0:
|
||||
remaining_budget = 0
|
||||
|
||||
budget_msg = ChatMessage.system(
|
||||
f"Your remaining API budget is ${remaining_budget:.3f}"
|
||||
+ (
|
||||
" BUDGET EXCEEDED! SHUT DOWN!\n\n"
|
||||
if remaining_budget == 0
|
||||
else " Budget very nearly exceeded! Shut down gracefully!\n\n"
|
||||
if remaining_budget < 0.005
|
||||
else " Budget nearly exceeded. Finish up.\n\n"
|
||||
if remaining_budget < 0.01
|
||||
else ""
|
||||
),
|
||||
)
|
||||
logger.debug(budget_msg)
|
||||
extra_messages.append(budget_msg)
|
||||
|
||||
if include_os_info is None:
|
||||
include_os_info = self.legacy_config.execute_local_commands
|
||||
|
||||
return super().build_prompt(
|
||||
*args,
|
||||
extra_messages=extra_messages,
|
||||
include_os_info=include_os_info,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def on_before_think(self, *args, **kwargs) -> ChatPrompt:
|
||||
prompt = super().on_before_think(*args, **kwargs)
|
||||
|
||||
self.log_cycle_handler.log_count_within_cycle = 0
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
prompt.raw(),
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def parse_and_process_response(
|
||||
self, llm_response: AssistantChatMessage, *args, **kwargs
|
||||
) -> Agent.ThoughtProcessOutput:
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_planning():
|
||||
continue
|
||||
llm_response.content = plugin.post_planning(llm_response.content or "")
|
||||
logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
|
||||
output = await self.complete_and_parse(prompt)
|
||||
self.config.cycle_count += 1
|
||||
|
||||
(
|
||||
command_name,
|
||||
arguments,
|
||||
assistant_reply_dict,
|
||||
) = self.prompt_strategy.parse_response_content(llm_response)
|
||||
return output
|
||||
|
||||
async def complete_and_parse(
|
||||
self, prompt: ChatPrompt, exception: Optional[Exception] = None
|
||||
) -> OneShotAgentActionProposal:
|
||||
if exception:
|
||||
prompt.messages.append(ChatMessage.system(f"Error: {exception}"))
|
||||
|
||||
response: ChatModelResponse[
|
||||
OneShotAgentActionProposal
|
||||
] = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
model_name=self.llm.name,
|
||||
completion_parser=self.prompt_strategy.parse_response_content,
|
||||
functions=prompt.functions,
|
||||
prefill_response=prompt.prefill_response,
|
||||
)
|
||||
result = response.parsed_result
|
||||
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
assistant_reply_dict,
|
||||
result.thoughts.dict(),
|
||||
NEXT_ACTION_FILE_NAME,
|
||||
)
|
||||
|
||||
if command_name:
|
||||
self.event_history.register_action(
|
||||
Action(
|
||||
name=command_name,
|
||||
args=arguments,
|
||||
reasoning=assistant_reply_dict["thoughts"]["reasoning"],
|
||||
)
|
||||
)
|
||||
|
||||
return command_name, arguments, assistant_reply_dict
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_name: str,
|
||||
command_args: dict[str, str] = {},
|
||||
user_input: str = "",
|
||||
) -> ActionResult:
|
||||
result: ActionResult
|
||||
|
||||
if command_name == "human_feedback":
|
||||
result = ActionInterruptedByHuman(feedback=user_input)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
user_input,
|
||||
USER_INPUT_FILE_NAME,
|
||||
)
|
||||
|
||||
else:
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_pre_command():
|
||||
continue
|
||||
command_name, command_args = plugin.pre_command(
|
||||
command_name, command_args
|
||||
)
|
||||
|
||||
try:
|
||||
return_value = await execute_command(
|
||||
command_name=command_name,
|
||||
arguments=command_args,
|
||||
agent=self,
|
||||
)
|
||||
|
||||
# Intercept ContextItem if one is returned by the command
|
||||
if type(return_value) is tuple and isinstance(
|
||||
return_value[1], ContextItem
|
||||
):
|
||||
context_item = return_value[1]
|
||||
return_value = return_value[0]
|
||||
logger.debug(
|
||||
f"Command {command_name} returned a ContextItem: {context_item}"
|
||||
)
|
||||
self.context.add(context_item)
|
||||
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentTerminated:
|
||||
raise
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult.from_exception(e)
|
||||
logger.warning(
|
||||
f"{command_name}({fmt_kwargs(command_args)}) raised an error: {e}"
|
||||
)
|
||||
|
||||
result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
|
||||
if result_tlength > self.send_token_limit // 3:
|
||||
result = ActionErrorResult(
|
||||
reason=f"Command {command_name} returned too much output. "
|
||||
"Do not execute this command again with the same arguments."
|
||||
)
|
||||
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_command():
|
||||
continue
|
||||
if result.status == "success":
|
||||
result.outputs = plugin.post_command(command_name, result.outputs)
|
||||
elif result.status == "error":
|
||||
result.reason = plugin.post_command(command_name, result.reason)
|
||||
|
||||
# Update action history
|
||||
self.event_history.register_result(result)
|
||||
await self.run_pipeline(AfterParse.after_parse, result)
|
||||
|
||||
return result
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
proposal: OneShotAgentActionProposal,
|
||||
user_feedback: str = "",
|
||||
) -> ActionResult:
|
||||
tool = proposal.use_tool
|
||||
|
||||
#############
|
||||
# Utilities #
|
||||
#############
|
||||
# Get commands
|
||||
self.commands = await self.run_pipeline(CommandProvider.get_commands)
|
||||
self._remove_disabled_commands()
|
||||
|
||||
|
||||
async def execute_command(
|
||||
command_name: str,
|
||||
arguments: dict[str, str],
|
||||
agent: Agent,
|
||||
) -> CommandOutput:
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
command_name (str): The name of the command to execute
|
||||
arguments (dict): The arguments for the command
|
||||
agent (Agent): The agent that is executing the command
|
||||
|
||||
Returns:
|
||||
str: The result of the command
|
||||
"""
|
||||
# Execute a native command with the same name or alias, if it exists
|
||||
if command := agent.command_registry.get_command(command_name):
|
||||
try:
|
||||
result = command(**arguments, agent=agent)
|
||||
return_value = await self._execute_tool(tool)
|
||||
|
||||
result = ActionSuccessResult(outputs=return_value)
|
||||
except AgentTerminated:
|
||||
raise
|
||||
except AgentException as e:
|
||||
result = ActionErrorResult.from_exception(e)
|
||||
logger.warning(f"{tool} raised an error: {e}")
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
result_tlength = self.llm_provider.count_tokens(str(result), self.llm.name)
|
||||
if result_tlength > self.send_token_limit // 3:
|
||||
result = ActionErrorResult(
|
||||
reason=f"Command {tool.name} returned too much output. "
|
||||
"Do not execute this command again with the same arguments."
|
||||
)
|
||||
|
||||
await self.run_pipeline(AfterExecute.after_execute, result)
|
||||
|
||||
logger.debug("\n".join(self.trace))
|
||||
|
||||
return result
|
||||
|
||||
async def do_not_execute(
|
||||
self, denied_proposal: OneShotAgentActionProposal, user_feedback: str
|
||||
) -> ActionResult:
|
||||
result = ActionInterruptedByHuman(feedback=user_feedback)
|
||||
self.log_cycle_handler.log_cycle(
|
||||
self.state.ai_profile.ai_name,
|
||||
self.created_at,
|
||||
self.config.cycle_count,
|
||||
user_feedback,
|
||||
USER_INPUT_FILE_NAME,
|
||||
)
|
||||
|
||||
await self.run_pipeline(AfterExecute.after_execute, result)
|
||||
|
||||
logger.debug("\n".join(self.trace))
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_tool(self, tool_call: AssistantFunctionCall) -> CommandOutput:
|
||||
"""Execute the command and return the result
|
||||
|
||||
Args:
|
||||
tool_call (AssistantFunctionCall): The tool call to execute
|
||||
|
||||
Returns:
|
||||
str: The execution result
|
||||
"""
|
||||
# Execute a native command with the same name or alias, if it exists
|
||||
command = self._get_command(tool_call.name)
|
||||
try:
|
||||
result = command(**tool_call.arguments)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
@@ -310,20 +296,31 @@ async def execute_command(
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(str(e))
|
||||
|
||||
# Handle non-native commands (e.g. from plugins)
|
||||
if agent._prompt_scratchpad:
|
||||
for name, command in agent._prompt_scratchpad.commands.items():
|
||||
if (
|
||||
command_name == name
|
||||
or command_name.lower() == command.description.lower()
|
||||
):
|
||||
try:
|
||||
return command.method(**arguments)
|
||||
except AgentException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(str(e))
|
||||
def _get_command(self, command_name: str) -> Command:
|
||||
for command in reversed(self.commands):
|
||||
if command_name in command.names:
|
||||
return command
|
||||
|
||||
raise UnknownCommandError(
|
||||
f"Cannot execute command '{command_name}': unknown command."
|
||||
)
|
||||
raise UnknownCommandError(
|
||||
f"Cannot execute command '{command_name}': unknown command."
|
||||
)
|
||||
|
||||
def _remove_disabled_commands(self) -> None:
|
||||
self.commands = [
|
||||
command
|
||||
for command in self.commands
|
||||
if not any(
|
||||
name in self.legacy_config.disabled_commands for name in command.names
|
||||
)
|
||||
]
|
||||
|
||||
def find_obscured_commands(self) -> list[Command]:
|
||||
seen_names = set()
|
||||
obscured_commands = []
|
||||
for command in reversed(self.commands):
|
||||
# If all of the command's names have been seen, it's obscured
|
||||
if seen_names.issuperset(command.names):
|
||||
obscured_commands.append(command)
|
||||
else:
|
||||
seen_names.update(command.names)
|
||||
return list(reversed(obscured_commands))
|
||||
|
||||
@@ -1,25 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Iterator,
|
||||
Optional,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from pydantic import Field, validator
|
||||
from colorama import Fore
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.prompting.base import PromptStrategy
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
)
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.models.action_history import ActionResult
|
||||
|
||||
from autogpt.agents.utils.prompt_scratchpad import PromptScratchpad
|
||||
from autogpt.agents import protocols as _protocols
|
||||
from autogpt.agents.components import (
|
||||
AgentComponent,
|
||||
ComponentEndpointError,
|
||||
EndpointPipelineError,
|
||||
)
|
||||
from autogpt.config import ConfigBuilder
|
||||
from autogpt.config.ai_directives import AIDirectives
|
||||
from autogpt.config.ai_profile import AIProfile
|
||||
@@ -29,34 +39,26 @@ from autogpt.core.configuration import (
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.prompting.schema import (
|
||||
ChatMessage,
|
||||
ChatPrompt,
|
||||
CompletionModelFunction,
|
||||
from autogpt.core.resource.model_providers import (
|
||||
CHAT_MODELS,
|
||||
AssistantFunctionCall,
|
||||
ModelName,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAIModelName,
|
||||
)
|
||||
from autogpt.core.runner.client_lib.logging.helpers import dump_prompt
|
||||
from autogpt.llm.providers.openai import get_openai_command_specs
|
||||
from autogpt.models.action_history import ActionResult, EpisodicActionHistory
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIModelName
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
from .utils.agent_file_manager import AgentFileManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CommandName = str
|
||||
CommandArgs = dict[str, str]
|
||||
AgentThoughts = dict[str, Any]
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
class BaseAgentConfiguration(SystemConfiguration):
|
||||
allow_fs_access: bool = UserConfigurable(default=False)
|
||||
|
||||
fast_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: OpenAIModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
fast_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT3_16k)
|
||||
smart_llm: ModelName = UserConfigurable(default=OpenAIModelName.GPT4)
|
||||
use_functions_api: bool = UserConfigurable(default=False)
|
||||
|
||||
default_cycle_instruction: str = DEFAULT_TRIGGERING_PROMPT
|
||||
@@ -92,21 +94,6 @@ class BaseAgentConfiguration(SystemConfiguration):
|
||||
summary_max_tlength: Optional[int] = None
|
||||
# TODO: move to ActionHistoryConfiguration
|
||||
|
||||
plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True # Necessary for plugins
|
||||
|
||||
@validator("plugins", each_item=True)
|
||||
def validate_plugins(cls, p: AutoGPTPluginTemplate | Any):
|
||||
assert issubclass(
|
||||
p.__class__, AutoGPTPluginTemplate
|
||||
), f"{p} does not subclass AutoGPTPluginTemplate"
|
||||
assert (
|
||||
p.__class__.__name__ != "AutoGPTPluginTemplate"
|
||||
), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance"
|
||||
return p
|
||||
|
||||
@validator("use_functions_api")
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
@@ -126,7 +113,6 @@ class BaseAgentConfiguration(SystemConfiguration):
|
||||
|
||||
class BaseAgentSettings(SystemSettings):
|
||||
agent_id: str = ""
|
||||
agent_data_dir: Optional[Path] = None
|
||||
|
||||
ai_profile: AIProfile = Field(default_factory=lambda: AIProfile(ai_name="AutoGPT"))
|
||||
"""The AI profile or "personality" of the agent."""
|
||||
@@ -144,79 +130,44 @@ class BaseAgentSettings(SystemSettings):
|
||||
config: BaseAgentConfiguration = Field(default_factory=BaseAgentConfiguration)
|
||||
"""The configuration for this BaseAgent subsystem instance."""
|
||||
|
||||
history: EpisodicActionHistory = Field(default_factory=EpisodicActionHistory)
|
||||
"""(STATE) The action history of the agent."""
|
||||
|
||||
def save_to_json_file(self, file_path: Path) -> None:
|
||||
with file_path.open("w") as f:
|
||||
f.write(self.json())
|
||||
|
||||
@classmethod
|
||||
def load_from_json_file(cls, file_path: Path):
|
||||
return cls.parse_file(file_path)
|
||||
class AgentMeta(ABCMeta):
|
||||
def __call__(cls, *args, **kwargs):
|
||||
# Create instance of the class (Agent or BaseAgent)
|
||||
instance = super().__call__(*args, **kwargs)
|
||||
# Automatically collect modules after the instance is created
|
||||
instance._collect_components()
|
||||
return instance
|
||||
|
||||
|
||||
class BaseAgent(Configurable[BaseAgentSettings], ABC):
|
||||
"""Base class for all AutoGPT agent classes."""
|
||||
class BaseAgentActionProposal(BaseModel):
|
||||
thoughts: str | ModelWithSummary
|
||||
use_tool: AssistantFunctionCall = None
|
||||
|
||||
ThoughtProcessOutput = tuple[CommandName, CommandArgs, AgentThoughts]
|
||||
|
||||
class BaseAgent(Configurable[BaseAgentSettings], metaclass=AgentMeta):
|
||||
C = TypeVar("C", bound=AgentComponent)
|
||||
|
||||
default_settings = BaseAgentSettings(
|
||||
name="BaseAgent",
|
||||
description=__doc__,
|
||||
description=__doc__ if __doc__ else "",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: BaseAgentSettings,
|
||||
llm_provider: ChatModelProvider,
|
||||
prompt_strategy: PromptStrategy,
|
||||
command_registry: CommandRegistry,
|
||||
legacy_config: Config,
|
||||
):
|
||||
self.state = settings
|
||||
self.components: list[AgentComponent] = []
|
||||
self.config = settings.config
|
||||
self.ai_profile = settings.ai_profile
|
||||
self.directives = settings.directives
|
||||
self.event_history = settings.history
|
||||
# Execution data for debugging
|
||||
self._trace: list[str] = []
|
||||
|
||||
self.legacy_config = legacy_config
|
||||
"""LEGACY: Monolithic application configuration."""
|
||||
logger.debug(f"Created {__class__} '{self.state.ai_profile.ai_name}'")
|
||||
|
||||
self.file_manager: AgentFileManager = (
|
||||
AgentFileManager(settings.agent_data_dir)
|
||||
if settings.agent_data_dir
|
||||
else None
|
||||
) # type: ignore
|
||||
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
self.prompt_strategy = prompt_strategy
|
||||
|
||||
self.command_registry = command_registry
|
||||
"""The registry containing all commands available to the agent."""
|
||||
|
||||
self._prompt_scratchpad: PromptScratchpad | None = None
|
||||
|
||||
# Support multi-inheritance and mixins for subclasses
|
||||
super(BaseAgent, self).__init__()
|
||||
|
||||
logger.debug(f"Created {__class__} '{self.ai_profile.ai_name}'")
|
||||
|
||||
def set_id(self, new_id: str, new_agent_dir: Optional[Path] = None):
|
||||
self.state.agent_id = new_id
|
||||
if self.state.agent_data_dir:
|
||||
if not new_agent_dir:
|
||||
raise ValueError(
|
||||
"new_agent_dir must be specified if one is currently configured"
|
||||
)
|
||||
self.attach_fs(new_agent_dir)
|
||||
|
||||
def attach_fs(self, agent_dir: Path) -> AgentFileManager:
|
||||
self.file_manager = AgentFileManager(agent_dir)
|
||||
self.file_manager.initialize()
|
||||
self.state.agent_data_dir = agent_dir
|
||||
return self.file_manager
|
||||
@property
|
||||
def trace(self) -> list[str]:
|
||||
return self._trace
|
||||
|
||||
@property
|
||||
def llm(self) -> ChatModelInfo:
|
||||
@@ -224,208 +175,180 @@ class BaseAgent(Configurable[BaseAgentSettings], ABC):
|
||||
llm_name = (
|
||||
self.config.smart_llm if self.config.big_brain else self.config.fast_llm
|
||||
)
|
||||
return OPEN_AI_CHAT_MODELS[llm_name]
|
||||
return CHAT_MODELS[llm_name]
|
||||
|
||||
@property
|
||||
def send_token_limit(self) -> int:
|
||||
return self.config.send_token_limit or self.llm.max_tokens * 3 // 4
|
||||
|
||||
async def propose_action(self) -> ThoughtProcessOutput:
|
||||
"""Proposes the next action to execute, based on the task and current state.
|
||||
|
||||
Returns:
|
||||
The command name and arguments, if any, and the agent's thoughts.
|
||||
"""
|
||||
assert self.file_manager, (
|
||||
f"Agent has no FileManager: call {__class__.__name__}.attach_fs()"
|
||||
" before trying to run the agent."
|
||||
)
|
||||
|
||||
# Scratchpad as surrogate PromptGenerator for plugin hooks
|
||||
self._prompt_scratchpad = PromptScratchpad()
|
||||
|
||||
prompt: ChatPrompt = self.build_prompt(scratchpad=self._prompt_scratchpad)
|
||||
prompt = self.on_before_think(prompt, scratchpad=self._prompt_scratchpad)
|
||||
|
||||
logger.debug(f"Executing prompt:\n{dump_prompt(prompt)}")
|
||||
response = await self.llm_provider.create_chat_completion(
|
||||
prompt.messages,
|
||||
functions=get_openai_command_specs(
|
||||
self.command_registry.list_available_commands(self)
|
||||
)
|
||||
+ list(self._prompt_scratchpad.commands.values())
|
||||
if self.config.use_functions_api
|
||||
else [],
|
||||
model_name=self.llm.name,
|
||||
completion_parser=lambda r: self.parse_and_process_response(
|
||||
r,
|
||||
prompt,
|
||||
scratchpad=self._prompt_scratchpad,
|
||||
),
|
||||
)
|
||||
self.config.cycle_count += 1
|
||||
|
||||
return self.on_response(
|
||||
llm_response=response,
|
||||
prompt=prompt,
|
||||
scratchpad=self._prompt_scratchpad,
|
||||
)
|
||||
@abstractmethod
|
||||
async def propose_action(self) -> BaseAgentActionProposal:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
command_name: str,
|
||||
command_args: dict[str, str] = {},
|
||||
user_input: str = "",
|
||||
proposal: BaseAgentActionProposal,
|
||||
user_feedback: str = "",
|
||||
) -> ActionResult:
|
||||
"""Executes the given command, if any, and returns the agent's response.
|
||||
|
||||
Params:
|
||||
command_name: The name of the command to execute, if any.
|
||||
command_args: The arguments to pass to the command, if any.
|
||||
user_input: The user's input, if any.
|
||||
|
||||
Returns:
|
||||
ActionResult: An object representing the result(s) of the command.
|
||||
"""
|
||||
...
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
scratchpad: PromptScratchpad,
|
||||
extra_commands: Optional[list[CompletionModelFunction]] = None,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs a prompt using `self.prompt_strategy`.
|
||||
|
||||
Params:
|
||||
scratchpad: An object for plugins to write additional prompt elements to.
|
||||
(E.g. commands, constraints, best practices)
|
||||
extra_commands: Additional commands that the agent has access to.
|
||||
extra_messages: Additional messages to include in the prompt.
|
||||
"""
|
||||
if not extra_commands:
|
||||
extra_commands = []
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
# Apply additions from plugins
|
||||
for plugin in self.config.plugins:
|
||||
if not plugin.can_handle_post_prompt():
|
||||
continue
|
||||
plugin.post_prompt(scratchpad)
|
||||
ai_directives = self.directives.copy(deep=True)
|
||||
ai_directives.resources += scratchpad.resources
|
||||
ai_directives.constraints += scratchpad.constraints
|
||||
ai_directives.best_practices += scratchpad.best_practices
|
||||
extra_commands += list(scratchpad.commands.values())
|
||||
|
||||
prompt = self.prompt_strategy.build_prompt(
|
||||
task=self.state.task,
|
||||
ai_profile=self.ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=get_openai_command_specs(
|
||||
self.command_registry.list_available_commands(self)
|
||||
)
|
||||
+ extra_commands,
|
||||
event_history=self.event_history,
|
||||
max_prompt_tokens=self.send_token_limit,
|
||||
count_tokens=lambda x: self.llm_provider.count_tokens(x, self.llm.name),
|
||||
count_message_tokens=lambda x: self.llm_provider.count_message_tokens(
|
||||
x, self.llm.name
|
||||
),
|
||||
extra_messages=extra_messages,
|
||||
**extras,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def on_before_think(
|
||||
self,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ChatPrompt:
|
||||
"""Called after constructing the prompt but before executing it.
|
||||
|
||||
Calls the `on_planning` hook of any enabled and capable plugins, adding their
|
||||
output to the prompt.
|
||||
|
||||
Params:
|
||||
prompt: The prompt that is about to be executed.
|
||||
scratchpad: An object for plugins to write additional prompt elements to.
|
||||
(E.g. commands, constraints, best practices)
|
||||
|
||||
Returns:
|
||||
The prompt to execute
|
||||
"""
|
||||
current_tokens_used = self.llm_provider.count_message_tokens(
|
||||
prompt.messages, self.llm.name
|
||||
)
|
||||
plugin_count = len(self.config.plugins)
|
||||
for i, plugin in enumerate(self.config.plugins):
|
||||
if not plugin.can_handle_on_planning():
|
||||
continue
|
||||
plugin_response = plugin.on_planning(scratchpad, prompt.raw())
|
||||
if not plugin_response or plugin_response == "":
|
||||
continue
|
||||
message_to_add = ChatMessage.system(plugin_response)
|
||||
tokens_to_add = self.llm_provider.count_message_tokens(
|
||||
message_to_add, self.llm.name
|
||||
)
|
||||
if current_tokens_used + tokens_to_add > self.send_token_limit:
|
||||
logger.debug(f"Plugin response too long, skipping: {plugin_response}")
|
||||
logger.debug(f"Plugins remaining at stop: {plugin_count - i}")
|
||||
break
|
||||
prompt.messages.insert(
|
||||
-1, message_to_add
|
||||
) # HACK: assumes cycle instruction to be at the end
|
||||
current_tokens_used += tokens_to_add
|
||||
return prompt
|
||||
|
||||
def on_response(
|
||||
self,
|
||||
llm_response: ChatModelResponse,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ThoughtProcessOutput:
|
||||
"""Called upon receiving a response from the chat model.
|
||||
|
||||
Calls `self.parse_and_process_response()`.
|
||||
|
||||
Params:
|
||||
llm_response: The raw response from the chat model.
|
||||
prompt: The prompt that was executed.
|
||||
scratchpad: An object containing additional prompt elements from plugins.
|
||||
(E.g. commands, constraints, best practices)
|
||||
|
||||
Returns:
|
||||
The parsed command name and command args, if any, and the agent thoughts.
|
||||
"""
|
||||
|
||||
return llm_response.parsed_result
|
||||
|
||||
# TODO: update memory/context
|
||||
|
||||
@abstractmethod
|
||||
def parse_and_process_response(
|
||||
async def do_not_execute(
|
||||
self,
|
||||
llm_response: AssistantChatMessage,
|
||||
prompt: ChatPrompt,
|
||||
scratchpad: PromptScratchpad,
|
||||
) -> ThoughtProcessOutput:
|
||||
"""Validate, parse & process the LLM's response.
|
||||
denied_proposal: BaseAgentActionProposal,
|
||||
user_feedback: str,
|
||||
) -> ActionResult:
|
||||
...
|
||||
|
||||
Must be implemented by derivative classes: no base implementation is provided,
|
||||
since the implementation depends on the role of the derivative Agent.
|
||||
def reset_trace(self):
|
||||
self._trace = []
|
||||
|
||||
Params:
|
||||
llm_response: The raw response from the chat model.
|
||||
prompt: The prompt that was executed.
|
||||
scratchpad: An object containing additional prompt elements from plugins.
|
||||
(E.g. commands, constraints, best practices)
|
||||
@overload
|
||||
async def run_pipeline(
|
||||
self, protocol_method: Callable[P, Iterator[T]], *args, retry_limit: int = 3
|
||||
) -> list[T]:
|
||||
...
|
||||
|
||||
Returns:
|
||||
The parsed command name and command args, if any, and the agent thoughts.
|
||||
"""
|
||||
pass
|
||||
@overload
|
||||
async def run_pipeline(
|
||||
self, protocol_method: Callable[P, None], *args, retry_limit: int = 3
|
||||
) -> list[None]:
|
||||
...
|
||||
|
||||
async def run_pipeline(
|
||||
self,
|
||||
protocol_method: Callable[P, Iterator[T] | None],
|
||||
*args,
|
||||
retry_limit: int = 3,
|
||||
) -> list[T] | list[None]:
|
||||
method_name = protocol_method.__name__
|
||||
protocol_name = protocol_method.__qualname__.split(".")[0]
|
||||
protocol_class = getattr(_protocols, protocol_name)
|
||||
if not issubclass(protocol_class, AgentComponent):
|
||||
raise TypeError(f"{repr(protocol_method)} is not a protocol method")
|
||||
|
||||
# Clone parameters to revert on failure
|
||||
original_args = self._selective_copy(args)
|
||||
pipeline_attempts = 0
|
||||
method_result: list[T] = []
|
||||
self._trace.append(f"⬇️ {Fore.BLUE}{method_name}{Fore.RESET}")
|
||||
|
||||
while pipeline_attempts < retry_limit:
|
||||
try:
|
||||
for component in self.components:
|
||||
# Skip other protocols
|
||||
if not isinstance(component, protocol_class):
|
||||
continue
|
||||
|
||||
# Skip disabled components
|
||||
if not component.enabled:
|
||||
self._trace.append(
|
||||
f" {Fore.LIGHTBLACK_EX}"
|
||||
f"{component.__class__.__name__}{Fore.RESET}"
|
||||
)
|
||||
continue
|
||||
|
||||
method = getattr(component, method_name, None)
|
||||
if not callable(method):
|
||||
continue
|
||||
|
||||
component_attempts = 0
|
||||
while component_attempts < retry_limit:
|
||||
try:
|
||||
component_args = self._selective_copy(args)
|
||||
if inspect.iscoroutinefunction(method):
|
||||
result = await method(*component_args)
|
||||
else:
|
||||
result = method(*component_args)
|
||||
if result is not None:
|
||||
method_result.extend(result)
|
||||
args = component_args
|
||||
self._trace.append(f"✅ {component.__class__.__name__}")
|
||||
|
||||
except ComponentEndpointError:
|
||||
self._trace.append(
|
||||
f"❌ {Fore.YELLOW}{component.__class__.__name__}: "
|
||||
f"ComponentEndpointError{Fore.RESET}"
|
||||
)
|
||||
# Retry the same component on ComponentEndpointError
|
||||
component_attempts += 1
|
||||
continue
|
||||
# Successful component execution
|
||||
break
|
||||
# Successful pipeline execution
|
||||
break
|
||||
except EndpointPipelineError:
|
||||
self._trace.append(
|
||||
f"❌ {Fore.LIGHTRED_EX}{component.__class__.__name__}: "
|
||||
f"EndpointPipelineError{Fore.RESET}"
|
||||
)
|
||||
# Restart from the beginning on EndpointPipelineError
|
||||
# Revert to original parameters
|
||||
args = self._selective_copy(original_args)
|
||||
pipeline_attempts += 1
|
||||
continue # Start the loop over
|
||||
except Exception as e:
|
||||
raise e
|
||||
return method_result
|
||||
|
||||
def _collect_components(self):
|
||||
components = [
|
||||
getattr(self, attr)
|
||||
for attr in dir(self)
|
||||
if isinstance(getattr(self, attr), AgentComponent)
|
||||
]
|
||||
|
||||
if self.components:
|
||||
# Check if any coponent is missed (added to Agent but not to components)
|
||||
for component in components:
|
||||
if component not in self.components:
|
||||
logger.warning(
|
||||
f"Component {component.__class__.__name__} "
|
||||
"is attached to an agent but not added to components list"
|
||||
)
|
||||
# Skip collecting anf sorting and sort if ordering is explicit
|
||||
return
|
||||
self.components = self._topological_sort(components)
|
||||
|
||||
def _topological_sort(
|
||||
self, components: list[AgentComponent]
|
||||
) -> list[AgentComponent]:
|
||||
visited = set()
|
||||
stack = []
|
||||
|
||||
def visit(node: AgentComponent):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
for neighbor_class in node.__class__.run_after:
|
||||
# Find the instance of neighbor_class in components
|
||||
neighbor = next(
|
||||
(m for m in components if isinstance(m, neighbor_class)), None
|
||||
)
|
||||
if neighbor:
|
||||
visit(neighbor)
|
||||
stack.append(node)
|
||||
|
||||
for component in components:
|
||||
visit(component)
|
||||
|
||||
return stack
|
||||
|
||||
def _selective_copy(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
|
||||
copied_args = []
|
||||
for item in args:
|
||||
if isinstance(item, list):
|
||||
# Shallow copy for lists
|
||||
copied_item = item[:]
|
||||
elif isinstance(item, dict):
|
||||
# Shallow copy for dicts
|
||||
copied_item = item.copy()
|
||||
elif isinstance(item, BaseModel):
|
||||
# Deep copy for Pydantic models (deep=True to also copy nested models)
|
||||
copied_item = item.copy(deep=True)
|
||||
else:
|
||||
# Deep copy for other objects
|
||||
copied_item = copy.deepcopy(item)
|
||||
copied_args.append(copied_item)
|
||||
return tuple(copied_args)
|
||||
|
||||
35
autogpts/autogpt/autogpt/agents/components.py
Normal file
35
autogpts/autogpt/autogpt/agents/components.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from abc import ABC
|
||||
from typing import Callable
|
||||
|
||||
|
||||
class AgentComponent(ABC):
|
||||
run_after: list[type["AgentComponent"]] = []
|
||||
_enabled: Callable[[], bool] | bool = True
|
||||
_disabled_reason: str = ""
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
if callable(self._enabled):
|
||||
return self._enabled()
|
||||
return self._enabled
|
||||
|
||||
@property
|
||||
def disabled_reason(self) -> str:
|
||||
return self._disabled_reason
|
||||
|
||||
|
||||
class ComponentEndpointError(Exception):
|
||||
"""Error of a single protocol method on a component."""
|
||||
|
||||
def __init__(self, message: str = ""):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EndpointPipelineError(ComponentEndpointError):
|
||||
"""Error of an entire pipline of one endpoint."""
|
||||
|
||||
|
||||
class ComponentSystemError(EndpointPipelineError):
|
||||
"""Error of a group of pipelines;
|
||||
multiple different enpoints."""
|
||||
161
autogpts/autogpt/autogpt/agents/features/agent_file_manager.py
Normal file
161
autogpts/autogpt/autogpt/agents/features/agent_file_manager.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.file_operations_utils import decode_textual_file
|
||||
|
||||
from ..base import BaseAgentSettings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileManagerComponent(DirectiveProvider, CommandProvider):
|
||||
"""
|
||||
Adds general file manager (e.g. Agent state),
|
||||
workspace manager (e.g. Agent output files) support and
|
||||
commands to perform operations on files and folders.
|
||||
"""
|
||||
|
||||
files: FileStorage
|
||||
"""Agent-related files, e.g. state, logs.
|
||||
Use `workspace` to access the agent's workspace files."""
|
||||
|
||||
workspace: FileStorage
|
||||
"""Workspace that the agent has access to, e.g. for reading/writing files.
|
||||
Use `files` to access agent-related files, e.g. state, logs."""
|
||||
|
||||
STATE_FILE = "state.json"
|
||||
"""The name of the file where the agent's state is stored."""
|
||||
|
||||
def __init__(self, state: BaseAgentSettings, file_storage: FileStorage):
|
||||
self.state = state
|
||||
|
||||
if not state.agent_id:
|
||||
raise ValueError("Agent must have an ID.")
|
||||
|
||||
self.files = file_storage.clone_with_subroot(f"agents/{state.agent_id}/")
|
||||
self.workspace = file_storage.clone_with_subroot(
|
||||
f"agents/{state.agent_id}/workspace"
|
||||
)
|
||||
self._file_storage = file_storage
|
||||
|
||||
async def save_state(self, save_as: Optional[str] = None) -> None:
|
||||
"""Save the agent's state to the state file."""
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
if save_as:
|
||||
temp_id = state.agent_id
|
||||
state.agent_id = save_as
|
||||
self._file_storage.make_dir(f"agents/{save_as}")
|
||||
# Save state
|
||||
await self._file_storage.write_file(
|
||||
f"agents/{save_as}/{self.STATE_FILE}", state.json()
|
||||
)
|
||||
# Copy workspace
|
||||
self._file_storage.copy(
|
||||
f"agents/{temp_id}/workspace",
|
||||
f"agents/{save_as}/workspace",
|
||||
)
|
||||
state.agent_id = temp_id
|
||||
else:
|
||||
await self.files.write_file(self.files.root / self.STATE_FILE, state.json())
|
||||
|
||||
def change_agent_id(self, new_id: str):
|
||||
"""Change the agent's ID and update the file storage accordingly."""
|
||||
state: BaseAgentSettings = getattr(self, "state")
|
||||
# Rename the agent's files and workspace
|
||||
self._file_storage.rename(f"agents/{state.agent_id}", f"agents/{new_id}")
|
||||
# Update the file storage objects
|
||||
self.files = self._file_storage.clone_with_subroot(f"agents/{new_id}/")
|
||||
self.workspace = self._file_storage.clone_with_subroot(
|
||||
f"agents/{new_id}/workspace"
|
||||
)
|
||||
state.agent_id = new_id
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "The ability to read and write files."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.read_file
|
||||
yield self.write_to_file
|
||||
yield self.list_folder
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to read",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def read_file(self, filename: str | Path) -> str:
|
||||
"""Read a file and return the contents
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to read
|
||||
|
||||
Returns:
|
||||
str: The contents of the file
|
||||
"""
|
||||
file = self.workspace.open_file(filename, binary=True)
|
||||
content = decode_textual_file(file, os.path.splitext(filename)[1], logger)
|
||||
|
||||
return content
|
||||
|
||||
@command(
|
||||
["write_file", "create_file"],
|
||||
"Write a file, creating it if necessary. "
|
||||
"If the file exists, it is overwritten.",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to write to",
|
||||
required=True,
|
||||
),
|
||||
"contents": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The contents to write to the file",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
async def write_to_file(self, filename: str | Path, contents: str) -> str:
|
||||
"""Write contents to a file
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to write to
|
||||
contents (str): The contents to write to the file
|
||||
|
||||
Returns:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
if directory := os.path.dirname(filename):
|
||||
self.workspace.make_dir(directory)
|
||||
await self.workspace.write_file(filename, contents)
|
||||
return f"File {filename} has been written successfully."
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"folder": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The folder to list files in",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def list_folder(self, folder: str | Path) -> list[str]:
|
||||
"""Lists files in a folder recursively
|
||||
|
||||
Args:
|
||||
folder (str): The folder to search in
|
||||
|
||||
Returns:
|
||||
list[str]: A list of files found in the folder
|
||||
"""
|
||||
return [str(p) for p in self.workspace.list_files(folder)]
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.models.context_item import ContextItem
|
||||
|
||||
from ..base import BaseAgent
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from autogpt.agents.protocols import CommandProvider, MessageProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.resource.model_providers import ChatMessage
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.models.context_item import ContextItem, FileContextItem, FolderContextItem
|
||||
from autogpt.utils.exceptions import InvalidArgumentError
|
||||
|
||||
|
||||
class AgentContext:
|
||||
@@ -32,51 +33,129 @@ class AgentContext:
|
||||
def clear(self) -> None:
|
||||
self.items.clear()
|
||||
|
||||
def format_numbered(self) -> str:
|
||||
return "\n\n".join([f"{i}. {c.fmt()}" for i, c in enumerate(self.items, 1)])
|
||||
def format_numbered(self, workspace: FileStorage) -> str:
|
||||
return "\n\n".join(
|
||||
[f"{i}. {c.fmt(workspace)}" for i, c in enumerate(self.items, 1)]
|
||||
)
|
||||
|
||||
|
||||
class ContextMixin:
|
||||
"""Mixin that adds context support to a BaseAgent subclass"""
|
||||
class ContextComponent(MessageProvider, CommandProvider):
|
||||
"""Adds ability to keep files and folders open in the context (prompt)."""
|
||||
|
||||
context: AgentContext
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
def __init__(self, workspace: FileStorage):
|
||||
self.context = AgentContext()
|
||||
self.workspace = workspace
|
||||
|
||||
super(ContextMixin, self).__init__(**kwargs)
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
*args: Any,
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatPrompt:
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
# Add context section to prompt
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
if self.context:
|
||||
extra_messages.insert(
|
||||
0,
|
||||
ChatMessage.system(
|
||||
"## Context\n"
|
||||
f"{self.context.format_numbered()}\n\n"
|
||||
"When a context item is no longer needed and you are not done yet, "
|
||||
"you can hide the item by specifying its number in the list above "
|
||||
"to `hide_context_item`.",
|
||||
),
|
||||
yield ChatMessage.system(
|
||||
"## Context\n"
|
||||
f"{self.context.format_numbered(self.workspace)}\n\n"
|
||||
"When a context item is no longer needed and you are not done yet, "
|
||||
"you can hide the item by specifying its number in the list above "
|
||||
"to `hide_context_item`.",
|
||||
)
|
||||
|
||||
return super(ContextMixin, self).build_prompt(
|
||||
*args,
|
||||
extra_messages=extra_messages,
|
||||
**kwargs,
|
||||
) # type: ignore
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.open_file
|
||||
yield self.open_folder
|
||||
if self.context:
|
||||
yield self.close_context_item
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"file_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to open",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
async def open_file(self, file_path: str | Path) -> str:
|
||||
"""Opens a file for editing or continued viewing;
|
||||
creates it if it does not exist yet.
|
||||
Note: If you only need to read or write a file once,
|
||||
use `write_to_file` instead.
|
||||
|
||||
def get_agent_context(agent: BaseAgent) -> AgentContext | None:
|
||||
if isinstance(agent, ContextMixin):
|
||||
return agent.context
|
||||
Args:
|
||||
file_path (str | Path): The path of the file to open
|
||||
|
||||
return None
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if not isinstance(file_path, Path):
|
||||
file_path = Path(file_path)
|
||||
|
||||
created = False
|
||||
if not self.workspace.exists(file_path):
|
||||
await self.workspace.write_file(file_path, "")
|
||||
created = True
|
||||
|
||||
# Try to make the file path relative
|
||||
with contextlib.suppress(ValueError):
|
||||
file_path = file_path.relative_to(self.workspace.root)
|
||||
|
||||
file = FileContextItem(path=file_path)
|
||||
self.context.add(file)
|
||||
return (
|
||||
f"File {file_path}{' created,' if created else ''} has been opened"
|
||||
" and added to the context ✅"
|
||||
)
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the folder to open",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
def open_folder(self, path: str | Path) -> str:
|
||||
"""Open a folder to keep track of its content
|
||||
|
||||
Args:
|
||||
path (str | Path): The path of the folder to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
|
||||
if not self.workspace.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"open_folder {path} failed: no such file or directory"
|
||||
)
|
||||
|
||||
# Try to make the path relative
|
||||
with contextlib.suppress(ValueError):
|
||||
path = path.relative_to(self.workspace.root)
|
||||
|
||||
folder = FolderContextItem(path=path)
|
||||
self.context.add(folder)
|
||||
return f"Folder {path} has been opened and added to the context ✅"
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"number": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The 1-based index of the context item to hide",
|
||||
required=True,
|
||||
)
|
||||
}
|
||||
)
|
||||
def close_context_item(self, number: int) -> str:
|
||||
"""Hide an open file, folder or other context item, to save tokens.
|
||||
|
||||
Args:
|
||||
number (int): The 1-based index of the context item to hide
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
"""
|
||||
if number > len(self.context.items) or number == 0:
|
||||
raise InvalidArgumentError(f"Index {number} out of range")
|
||||
|
||||
self.context.close(number)
|
||||
return f"Context item {number} hidden ✅"
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
from ..base import BaseAgent, Config
|
||||
|
||||
from autogpt.file_workspace import (
|
||||
FileWorkspace,
|
||||
FileWorkspaceBackendName,
|
||||
get_workspace,
|
||||
)
|
||||
|
||||
from ..base import AgentFileManager, BaseAgentSettings
|
||||
|
||||
|
||||
class FileWorkspaceMixin:
|
||||
"""Mixin that adds workspace support to a class"""
|
||||
|
||||
workspace: FileWorkspace = None
|
||||
"""Workspace that the agent has access to, e.g. for reading/writing files."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Initialize other bases first, because we need the config from BaseAgent
|
||||
super(FileWorkspaceMixin, self).__init__(**kwargs)
|
||||
|
||||
file_manager: AgentFileManager = getattr(self, "file_manager")
|
||||
if not file_manager:
|
||||
return
|
||||
|
||||
self._setup_workspace()
|
||||
|
||||
def attach_fs(self, agent_dir: Path):
|
||||
res = super(FileWorkspaceMixin, self).attach_fs(agent_dir)
|
||||
|
||||
self._setup_workspace()
|
||||
|
||||
return res
|
||||
|
||||
def _setup_workspace(self) -> None:
|
||||
settings: BaseAgentSettings = getattr(self, "state")
|
||||
assert settings.agent_id, "Cannot attach workspace to anonymous agent"
|
||||
app_config: Config = getattr(self, "legacy_config")
|
||||
file_manager: AgentFileManager = getattr(self, "file_manager")
|
||||
|
||||
ws_backend = app_config.workspace_backend
|
||||
local = ws_backend == FileWorkspaceBackendName.LOCAL
|
||||
workspace = get_workspace(
|
||||
backend=ws_backend,
|
||||
id=settings.agent_id if not local else "",
|
||||
root_path=file_manager.root / "workspace" if local else None,
|
||||
)
|
||||
if local and settings.config.allow_fs_access:
|
||||
workspace._restrict_to_root = False # type: ignore
|
||||
workspace.initialize()
|
||||
self.workspace = workspace
|
||||
|
||||
|
||||
def get_agent_workspace(agent: BaseAgent) -> FileWorkspace | None:
|
||||
if isinstance(agent, FileWorkspaceMixin):
|
||||
return agent.workspace
|
||||
|
||||
return None
|
||||
@@ -1,41 +1,35 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import ExitStack
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..base import BaseAgentConfiguration
|
||||
|
||||
from autogpt.agents.base import BaseAgentActionProposal, BaseAgentConfiguration
|
||||
from autogpt.agents.components import ComponentSystemError
|
||||
from autogpt.agents.features.context import ContextComponent
|
||||
from autogpt.agents.protocols import AfterParse
|
||||
from autogpt.models.action_history import EpisodicActionHistory
|
||||
|
||||
from ..base import BaseAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WatchdogMixin:
|
||||
class WatchdogComponent(AfterParse):
|
||||
"""
|
||||
Mixin that adds a watchdog feature to an agent class. Whenever the agent starts
|
||||
Adds a watchdog feature to an agent class. Whenever the agent starts
|
||||
looping, the watchdog will switch from the FAST_LLM to the SMART_LLM and re-think.
|
||||
"""
|
||||
|
||||
config: BaseAgentConfiguration
|
||||
event_history: EpisodicActionHistory
|
||||
run_after = [ContextComponent]
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
# Initialize other bases first, because we need the event_history from BaseAgent
|
||||
super(WatchdogMixin, self).__init__(**kwargs)
|
||||
def __init__(
|
||||
self,
|
||||
config: BaseAgentConfiguration,
|
||||
event_history: EpisodicActionHistory[BaseAgentActionProposal],
|
||||
):
|
||||
self.config = config
|
||||
self.event_history = event_history
|
||||
self.revert_big_brain = False
|
||||
|
||||
if not isinstance(self, BaseAgent):
|
||||
raise NotImplementedError(
|
||||
f"{__class__.__name__} can only be applied to BaseAgent derivatives"
|
||||
)
|
||||
|
||||
async def propose_action(self, *args, **kwargs) -> BaseAgent.ThoughtProcessOutput:
|
||||
command_name, command_args, thoughts = await super(
|
||||
WatchdogMixin, self
|
||||
).propose_action(*args, **kwargs)
|
||||
def after_parse(self, result: BaseAgentActionProposal) -> None:
|
||||
if self.revert_big_brain:
|
||||
self.config.big_brain = False
|
||||
self.revert_big_brain = False
|
||||
|
||||
if not self.config.big_brain and self.config.fast_llm != self.config.smart_llm:
|
||||
previous_command, previous_command_args = None, None
|
||||
@@ -44,33 +38,23 @@ class WatchdogMixin:
|
||||
previous_cycle = self.event_history.episodes[
|
||||
self.event_history.cursor - 1
|
||||
]
|
||||
previous_command = previous_cycle.action.name
|
||||
previous_command_args = previous_cycle.action.args
|
||||
previous_command = previous_cycle.action.use_tool.name
|
||||
previous_command_args = previous_cycle.action.use_tool.arguments
|
||||
|
||||
rethink_reason = ""
|
||||
|
||||
if not command_name:
|
||||
if not result.use_tool:
|
||||
rethink_reason = "AI did not specify a command"
|
||||
elif (
|
||||
command_name == previous_command
|
||||
and command_args == previous_command_args
|
||||
result.use_tool.name == previous_command
|
||||
and result.use_tool.arguments == previous_command_args
|
||||
):
|
||||
rethink_reason = f"Repititive command detected ({command_name})"
|
||||
rethink_reason = f"Repititive command detected ({result.use_tool.name})"
|
||||
|
||||
if rethink_reason:
|
||||
logger.info(f"{rethink_reason}, re-thinking with SMART_LLM...")
|
||||
with ExitStack() as stack:
|
||||
|
||||
@stack.callback
|
||||
def restore_state() -> None:
|
||||
# Executed after exiting the ExitStack context
|
||||
self.config.big_brain = False
|
||||
|
||||
# Remove partial record of current cycle
|
||||
self.event_history.rewind()
|
||||
|
||||
# Switch to SMART_LLM and re-think
|
||||
self.big_brain = True
|
||||
return await self.propose_action(*args, **kwargs)
|
||||
|
||||
return command_name, command_args, thoughts
|
||||
self.event_history.rewind()
|
||||
self.big_brain = True
|
||||
self.revert_big_brain = True
|
||||
# Trigger retry of all pipelines prior to this component
|
||||
raise ComponentSystemError()
|
||||
|
||||
@@ -4,15 +4,11 @@ import json
|
||||
import platform
|
||||
import re
|
||||
from logging import Logger
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
import distro
|
||||
from pydantic import Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.models.action_history import Episode
|
||||
|
||||
from autogpt.agents.utils.exceptions import InvalidAgentResponseError
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.config import AIDirectives, AIProfile
|
||||
from autogpt.core.configuration.schema import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.prompting import (
|
||||
@@ -26,8 +22,32 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
CompletionModelFunction,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.json_utils.utilities import extract_dict_from_response
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
from autogpt.core.utils.json_utils import extract_dict_from_json
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.prompts.utils import format_numbered_list
|
||||
from autogpt.utils.exceptions import InvalidAgentResponseError
|
||||
|
||||
_RESPONSE_INTERFACE_NAME = "AssistantResponse"
|
||||
|
||||
|
||||
class AssistantThoughts(ModelWithSummary):
|
||||
observations: str = Field(
|
||||
..., description="Relevant observations from your last action (if any)"
|
||||
)
|
||||
text: str = Field(..., description="Thoughts")
|
||||
reasoning: str = Field(..., description="Reasoning behind the thoughts")
|
||||
self_criticism: str = Field(..., description="Constructive self-criticism")
|
||||
plan: list[str] = Field(
|
||||
..., description="Short list that conveys the long-term plan"
|
||||
)
|
||||
speak: str = Field(..., description="Summary of thoughts, to say to user")
|
||||
|
||||
def summary(self) -> str:
|
||||
return self.text
|
||||
|
||||
|
||||
class OneShotAgentActionProposal(BaseAgentActionProposal):
|
||||
thoughts: AssistantThoughts
|
||||
|
||||
|
||||
class OneShotAgentPromptConfiguration(SystemConfiguration):
|
||||
@@ -55,70 +75,7 @@ class OneShotAgentPromptConfiguration(SystemConfiguration):
|
||||
"and respond using the JSON schema specified previously:"
|
||||
)
|
||||
|
||||
DEFAULT_RESPONSE_SCHEMA = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={
|
||||
"thoughts": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
properties={
|
||||
"observations": JSONSchema(
|
||||
description=(
|
||||
"Relevant observations from your last action (if any)"
|
||||
),
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=False,
|
||||
),
|
||||
"text": JSONSchema(
|
||||
description="Thoughts",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"reasoning": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"self_criticism": JSONSchema(
|
||||
description="Constructive self-criticism",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"plan": JSONSchema(
|
||||
description=(
|
||||
"Short markdown-style bullet list that conveys the "
|
||||
"long-term plan"
|
||||
),
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"speak": JSONSchema(
|
||||
description="Summary of thoughts, to say to user",
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
"command": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
properties={
|
||||
"name": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
body_template: str = UserConfigurable(default=DEFAULT_BODY_TEMPLATE)
|
||||
response_schema: dict = UserConfigurable(
|
||||
default_factory=DEFAULT_RESPONSE_SCHEMA.to_dict
|
||||
)
|
||||
choose_action_instruction: str = UserConfigurable(
|
||||
default=DEFAULT_CHOOSE_ACTION_INSTRUCTION
|
||||
)
|
||||
@@ -143,7 +100,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
logger: Logger,
|
||||
):
|
||||
self.config = configuration
|
||||
self.response_schema = JSONSchema.from_dict(configuration.response_schema)
|
||||
self.response_schema = JSONSchema.from_dict(OneShotAgentActionProposal.schema())
|
||||
self.logger = logger
|
||||
|
||||
@property
|
||||
@@ -153,81 +110,55 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
def build_prompt(
|
||||
self,
|
||||
*,
|
||||
messages: list[ChatMessage],
|
||||
task: str,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
event_history: list[Episode],
|
||||
include_os_info: bool,
|
||||
max_prompt_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
count_message_tokens: Callable[[ChatMessage | list[ChatMessage]], int],
|
||||
extra_messages: Optional[list[ChatMessage]] = None,
|
||||
**extras,
|
||||
) -> ChatPrompt:
|
||||
"""Constructs and returns a prompt with the following structure:
|
||||
1. System prompt
|
||||
2. Message history of the agent, truncated & prepended with running summary
|
||||
as needed
|
||||
3. `cycle_instruction`
|
||||
"""
|
||||
if not extra_messages:
|
||||
extra_messages = []
|
||||
|
||||
system_prompt = self.build_system_prompt(
|
||||
system_prompt, response_prefill = self.build_system_prompt(
|
||||
ai_profile=ai_profile,
|
||||
ai_directives=ai_directives,
|
||||
commands=commands,
|
||||
include_os_info=include_os_info,
|
||||
)
|
||||
system_prompt_tlength = count_message_tokens(ChatMessage.system(system_prompt))
|
||||
|
||||
user_task = f'"""{task}"""'
|
||||
user_task_tlength = count_message_tokens(ChatMessage.user(user_task))
|
||||
|
||||
response_format_instr = self.response_format_instruction(
|
||||
self.config.use_functions_api
|
||||
)
|
||||
extra_messages.append(ChatMessage.system(response_format_instr))
|
||||
|
||||
final_instruction_msg = ChatMessage.user(self.config.choose_action_instruction)
|
||||
final_instruction_tlength = count_message_tokens(final_instruction_msg)
|
||||
|
||||
if event_history:
|
||||
progress = self.compile_progress(
|
||||
event_history,
|
||||
count_tokens=count_tokens,
|
||||
max_tokens=(
|
||||
max_prompt_tokens
|
||||
- system_prompt_tlength
|
||||
- user_task_tlength
|
||||
- final_instruction_tlength
|
||||
- count_message_tokens(extra_messages)
|
||||
),
|
||||
)
|
||||
extra_messages.insert(
|
||||
0,
|
||||
ChatMessage.system(f"## Progress\n\n{progress}"),
|
||||
)
|
||||
|
||||
prompt = ChatPrompt(
|
||||
return ChatPrompt(
|
||||
messages=[
|
||||
ChatMessage.system(system_prompt),
|
||||
ChatMessage.user(user_task),
|
||||
*extra_messages,
|
||||
ChatMessage.user(f'"""{task}"""'),
|
||||
*messages,
|
||||
final_instruction_msg,
|
||||
],
|
||||
prefill_response=response_prefill,
|
||||
functions=commands if self.config.use_functions_api else [],
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
def build_system_prompt(
|
||||
self,
|
||||
ai_profile: AIProfile,
|
||||
ai_directives: AIDirectives,
|
||||
commands: list[CompletionModelFunction],
|
||||
include_os_info: bool,
|
||||
) -> str:
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Builds the system prompt.
|
||||
|
||||
Returns:
|
||||
str: The system prompt body
|
||||
str: The desired start for the LLM's response; used to steer the output
|
||||
"""
|
||||
response_fmt_instruction, response_prefill = self.response_format_instruction(
|
||||
self.config.use_functions_api
|
||||
)
|
||||
system_prompt_parts = (
|
||||
self._generate_intro_prompt(ai_profile)
|
||||
+ (self._generate_os_info() if include_os_info else [])
|
||||
@@ -248,82 +179,39 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
" in the next message. Your job is to complete the task while following"
|
||||
" your directives as given above, and terminate when your task is done."
|
||||
]
|
||||
+ ["## RESPONSE FORMAT\n" + response_fmt_instruction]
|
||||
)
|
||||
|
||||
# Join non-empty parts together into paragraph format
|
||||
return "\n\n".join(filter(None, system_prompt_parts)).strip("\n")
|
||||
return (
|
||||
"\n\n".join(filter(None, system_prompt_parts)).strip("\n"),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def compile_progress(
|
||||
self,
|
||||
episode_history: list[Episode],
|
||||
max_tokens: Optional[int] = None,
|
||||
count_tokens: Optional[Callable[[str], int]] = None,
|
||||
) -> str:
|
||||
if max_tokens and not count_tokens:
|
||||
raise ValueError("count_tokens is required if max_tokens is set")
|
||||
|
||||
steps: list[str] = []
|
||||
tokens: int = 0
|
||||
# start: int = len(episode_history)
|
||||
|
||||
for i, c in reversed(list(enumerate(episode_history))):
|
||||
step = f"### Step {i+1}: Executed `{c.action.format_call()}`\n"
|
||||
step += f'- **Reasoning:** "{c.action.reasoning}"\n'
|
||||
step += (
|
||||
f"- **Status:** `{c.result.status if c.result else 'did_not_finish'}`\n"
|
||||
)
|
||||
if c.result:
|
||||
if c.result.status == "success":
|
||||
result = str(c.result)
|
||||
result = "\n" + indent(result) if "\n" in result else result
|
||||
step += f"- **Output:** {result}"
|
||||
elif c.result.status == "error":
|
||||
step += f"- **Reason:** {c.result.reason}\n"
|
||||
if c.result.error:
|
||||
step += f"- **Error:** {c.result.error}\n"
|
||||
elif c.result.status == "interrupted_by_human":
|
||||
step += f"- **Feedback:** {c.result.feedback}\n"
|
||||
|
||||
if max_tokens and count_tokens:
|
||||
step_tokens = count_tokens(step)
|
||||
if tokens + step_tokens > max_tokens:
|
||||
break
|
||||
tokens += step_tokens
|
||||
|
||||
steps.insert(0, step)
|
||||
# start = i
|
||||
|
||||
# # TODO: summarize remaining
|
||||
# part = slice(0, start)
|
||||
|
||||
return "\n\n".join(steps)
|
||||
|
||||
def response_format_instruction(self, use_functions_api: bool) -> str:
|
||||
def response_format_instruction(self, use_functions_api: bool) -> tuple[str, str]:
|
||||
response_schema = self.response_schema.copy(deep=True)
|
||||
if (
|
||||
use_functions_api
|
||||
and response_schema.properties
|
||||
and "command" in response_schema.properties
|
||||
and "use_tool" in response_schema.properties
|
||||
):
|
||||
del response_schema.properties["command"]
|
||||
del response_schema.properties["use_tool"]
|
||||
|
||||
# Unindent for performance
|
||||
response_format = re.sub(
|
||||
r"\n\s+",
|
||||
"\n",
|
||||
response_schema.to_typescript_object_interface("Response"),
|
||||
)
|
||||
|
||||
instruction = (
|
||||
"Respond with pure JSON containing your thoughts, " "and invoke a tool."
|
||||
if use_functions_api
|
||||
else "Respond with pure JSON."
|
||||
response_schema.to_typescript_object_interface(_RESPONSE_INTERFACE_NAME),
|
||||
)
|
||||
response_prefill = f'{{\n "{list(response_schema.properties.keys())[0]}":'
|
||||
|
||||
return (
|
||||
f"{instruction} "
|
||||
"The JSON object should be compatible with the TypeScript type `Response` "
|
||||
f"from the following:\n{response_format}"
|
||||
(
|
||||
f"YOU MUST ALWAYS RESPOND WITH A JSON OBJECT OF THE FOLLOWING TYPE:\n"
|
||||
f"{response_format}"
|
||||
+ ("\n\nYOU MUST ALSO INVOKE A TOOL!" if use_functions_api else "")
|
||||
),
|
||||
response_prefill,
|
||||
)
|
||||
|
||||
def _generate_intro_prompt(self, ai_profile: AIProfile) -> list[str]:
|
||||
@@ -387,7 +275,7 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
def parse_response_content(
|
||||
self,
|
||||
response: AssistantChatMessage,
|
||||
) -> Agent.ThoughtProcessOutput:
|
||||
) -> OneShotAgentActionProposal:
|
||||
if not response.content:
|
||||
raise InvalidAgentResponseError("Assistant response has no text content")
|
||||
|
||||
@@ -399,86 +287,15 @@ class OneShotAgentPromptStrategy(PromptStrategy):
|
||||
else f" '{response.content}'"
|
||||
)
|
||||
)
|
||||
assistant_reply_dict = extract_dict_from_response(response.content)
|
||||
assistant_reply_dict = extract_dict_from_json(response.content)
|
||||
self.logger.debug(
|
||||
"Validating object extracted from LLM response:\n"
|
||||
"Parsing object extracted from LLM response:\n"
|
||||
f"{json.dumps(assistant_reply_dict, indent=4)}"
|
||||
)
|
||||
|
||||
_, errors = self.response_schema.validate_object(
|
||||
object=assistant_reply_dict,
|
||||
logger=self.logger,
|
||||
)
|
||||
if errors:
|
||||
raise InvalidAgentResponseError(
|
||||
"Validation of response failed:\n "
|
||||
+ ";\n ".join([str(e) for e in errors])
|
||||
)
|
||||
|
||||
# Get command name and arguments
|
||||
command_name, arguments = extract_command(
|
||||
assistant_reply_dict, response, self.config.use_functions_api
|
||||
)
|
||||
return command_name, arguments, assistant_reply_dict
|
||||
|
||||
|
||||
#############
|
||||
# Utilities #
|
||||
#############
|
||||
|
||||
|
||||
def extract_command(
|
||||
assistant_reply_json: dict,
|
||||
assistant_reply: AssistantChatMessage,
|
||||
use_openai_functions_api: bool,
|
||||
) -> tuple[str, dict[str, str]]:
|
||||
"""Parse the response and return the command name and arguments
|
||||
|
||||
Args:
|
||||
assistant_reply_json (dict): The response object from the AI
|
||||
assistant_reply (AssistantChatMessage): The model response from the AI
|
||||
config (Config): The config object
|
||||
|
||||
Returns:
|
||||
tuple: The command name and arguments
|
||||
|
||||
Raises:
|
||||
json.decoder.JSONDecodeError: If the response is not valid JSON
|
||||
|
||||
Exception: If any other error occurs
|
||||
"""
|
||||
if use_openai_functions_api:
|
||||
if not assistant_reply.tool_calls:
|
||||
raise InvalidAgentResponseError("No 'tool_calls' in assistant reply")
|
||||
assistant_reply_json["command"] = {
|
||||
"name": assistant_reply.tool_calls[0].function.name,
|
||||
"args": json.loads(assistant_reply.tool_calls[0].function.arguments),
|
||||
}
|
||||
try:
|
||||
if not isinstance(assistant_reply_json, dict):
|
||||
raise InvalidAgentResponseError(
|
||||
f"The previous message sent was not a dictionary {assistant_reply_json}"
|
||||
)
|
||||
|
||||
if "command" not in assistant_reply_json:
|
||||
raise InvalidAgentResponseError("Missing 'command' object in JSON")
|
||||
|
||||
command = assistant_reply_json["command"]
|
||||
if not isinstance(command, dict):
|
||||
raise InvalidAgentResponseError("'command' object is not a dictionary")
|
||||
|
||||
if "name" not in command:
|
||||
raise InvalidAgentResponseError("Missing 'name' field in 'command' object")
|
||||
|
||||
command_name = command["name"]
|
||||
|
||||
# Use an empty dictionary if 'args' field is not present in 'command' object
|
||||
arguments = command.get("args", {})
|
||||
|
||||
return command_name, arguments
|
||||
|
||||
except json.decoder.JSONDecodeError:
|
||||
raise InvalidAgentResponseError("Invalid JSON")
|
||||
|
||||
except Exception as e:
|
||||
raise InvalidAgentResponseError(str(e))
|
||||
parsed_response = OneShotAgentActionProposal.parse_obj(assistant_reply_dict)
|
||||
if self.config.use_functions_api:
|
||||
if not response.tool_calls:
|
||||
raise InvalidAgentResponseError("Assistant did not use a tool")
|
||||
parsed_response.use_tool = response.tool_calls[0].function
|
||||
return parsed_response
|
||||
|
||||
51
autogpts/autogpt/autogpt/agents/protocols.py
Normal file
51
autogpts/autogpt/autogpt/agents/protocols.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Iterator
|
||||
|
||||
from autogpt.agents.components import AgentComponent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage
|
||||
from autogpt.models.action_history import ActionResult
|
||||
from autogpt.models.command import Command
|
||||
|
||||
|
||||
class DirectiveProvider(AgentComponent):
|
||||
def get_constraints(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
def get_best_practices(self) -> Iterator[str]:
|
||||
return iter([])
|
||||
|
||||
|
||||
class CommandProvider(AgentComponent):
|
||||
@abstractmethod
|
||||
def get_commands(self) -> Iterator["Command"]:
|
||||
...
|
||||
|
||||
|
||||
class MessageProvider(AgentComponent):
|
||||
@abstractmethod
|
||||
def get_messages(self) -> Iterator["ChatMessage"]:
|
||||
...
|
||||
|
||||
|
||||
class AfterParse(AgentComponent):
|
||||
@abstractmethod
|
||||
def after_parse(self, result: "BaseAgentActionProposal") -> None:
|
||||
...
|
||||
|
||||
|
||||
class ExecutionFailure(AgentComponent):
|
||||
@abstractmethod
|
||||
def execution_failure(self, error: Exception) -> None:
|
||||
...
|
||||
|
||||
|
||||
class AfterExecute(AgentComponent):
|
||||
@abstractmethod
|
||||
def after_execute(self, result: "ActionResult") -> None:
|
||||
...
|
||||
@@ -1,37 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentFileManager:
|
||||
"""A class that represents a workspace for an AutoGPT agent."""
|
||||
|
||||
def __init__(self, agent_data_dir: Path):
|
||||
self._root = agent_data_dir.resolve()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the workspace."""
|
||||
return self._root
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
self.init_file_ops_log(self.file_ops_log_path)
|
||||
|
||||
@property
|
||||
def state_file_path(self) -> Path:
|
||||
return self.root / "state.json"
|
||||
|
||||
@property
|
||||
def file_ops_log_path(self) -> Path:
|
||||
return self.root / "file_logger.log"
|
||||
|
||||
@staticmethod
|
||||
def init_file_ops_log(file_logger_path: Path) -> Path:
|
||||
if not file_logger_path.exists():
|
||||
with file_logger_path.open(mode="w", encoding="utf-8") as f:
|
||||
f.write("")
|
||||
return file_logger_path
|
||||
@@ -1,108 +0,0 @@
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.core.resource.model_providers.schema import CompletionModelFunction
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
logger = logging.getLogger("PromptScratchpad")
|
||||
|
||||
|
||||
class CallableCompletionModelFunction(CompletionModelFunction):
|
||||
method: Callable
|
||||
|
||||
|
||||
class PromptScratchpad(BaseModel):
|
||||
commands: dict[str, CallableCompletionModelFunction] = Field(default_factory=dict)
|
||||
resources: list[str] = Field(default_factory=list)
|
||||
constraints: list[str] = Field(default_factory=list)
|
||||
best_practices: list[str] = Field(default_factory=list)
|
||||
|
||||
def add_constraint(self, constraint: str) -> None:
|
||||
"""
|
||||
Add a constraint to the constraints list.
|
||||
|
||||
Params:
|
||||
constraint (str): The constraint to be added.
|
||||
"""
|
||||
if constraint not in self.constraints:
|
||||
self.constraints.append(constraint)
|
||||
|
||||
def add_command(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
params: dict[str, str | dict],
|
||||
function: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
Registers a command.
|
||||
|
||||
*Should only be used by plugins.* Native commands should be added
|
||||
directly to the CommandRegistry.
|
||||
|
||||
Params:
|
||||
name (str): The name of the command (e.g. `command_name`).
|
||||
description (str): The description of the command.
|
||||
params (dict, optional): A dictionary containing argument names and their
|
||||
types. Defaults to an empty dictionary.
|
||||
function (callable, optional): A callable function to be called when
|
||||
the command is executed. Defaults to None.
|
||||
"""
|
||||
for p, s in params.items():
|
||||
invalid = False
|
||||
if type(s) is str and s not in JSONSchema.Type._value2member_map_:
|
||||
invalid = True
|
||||
logger.warning(
|
||||
f"Cannot add command '{name}':"
|
||||
f" parameter '{p}' has invalid type '{s}'."
|
||||
f" Valid types are: {JSONSchema.Type._value2member_map_.keys()}"
|
||||
)
|
||||
elif isinstance(s, dict):
|
||||
try:
|
||||
JSONSchema.from_dict(s)
|
||||
except KeyError:
|
||||
invalid = True
|
||||
if invalid:
|
||||
return
|
||||
|
||||
command = CallableCompletionModelFunction(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters={
|
||||
name: JSONSchema(type=JSONSchema.Type._value2member_map_[spec])
|
||||
if type(spec) is str
|
||||
else JSONSchema.from_dict(spec)
|
||||
for name, spec in params.items()
|
||||
},
|
||||
method=function,
|
||||
)
|
||||
|
||||
if name in self.commands:
|
||||
if description == self.commands[name].description:
|
||||
return
|
||||
logger.warning(
|
||||
f"Replacing command {self.commands[name]} with conflicting {command}"
|
||||
)
|
||||
self.commands[name] = command
|
||||
|
||||
def add_resource(self, resource: str) -> None:
|
||||
"""
|
||||
Add a resource to the resources list.
|
||||
|
||||
Params:
|
||||
resource (str): The resource to be added.
|
||||
"""
|
||||
if resource not in self.resources:
|
||||
self.resources.append(resource)
|
||||
|
||||
def add_best_practice(self, best_practice: str) -> None:
|
||||
"""
|
||||
Add an item to the list of best practices.
|
||||
|
||||
Params:
|
||||
best_practice (str): The best practice item to be added.
|
||||
"""
|
||||
if best_practice not in self.best_practices:
|
||||
self.best_practices.append(best_practice)
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from uuid import uuid4
|
||||
|
||||
@@ -25,23 +26,18 @@ from forge.sdk.model import (
|
||||
from forge.sdk.routes.agent_protocol import base_router
|
||||
from hypercorn.asyncio import serve as hypercorn_serve
|
||||
from hypercorn.config import Config as HypercornConfig
|
||||
from sentry_sdk import set_user
|
||||
|
||||
from autogpt.agent_factory.configurators import configure_agent_with_state
|
||||
from autogpt.agent_factory.generators import generate_agent_for_task
|
||||
from autogpt.agent_manager import AgentManager
|
||||
from autogpt.commands.system import finish
|
||||
from autogpt.commands.user_interaction import ask_user
|
||||
from autogpt.app.utils import is_port_free
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.core.resource.model_providers.schema import ModelProviderBudget
|
||||
from autogpt.file_workspace import (
|
||||
FileWorkspace,
|
||||
FileWorkspaceBackendName,
|
||||
get_workspace,
|
||||
)
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider, ModelProviderBudget
|
||||
from autogpt.file_storage import FileStorage
|
||||
from autogpt.models.action_history import ActionErrorResult, ActionSuccessResult
|
||||
from autogpt.utils.exceptions import AgentFinished
|
||||
from autogpt.utils.utils import DEFAULT_ASK_COMMAND, DEFAULT_FINISH_COMMAND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,17 +49,27 @@ class AgentProtocolServer:
|
||||
self,
|
||||
app_config: Config,
|
||||
database: AgentDB,
|
||||
file_storage: FileStorage,
|
||||
llm_provider: ChatModelProvider,
|
||||
):
|
||||
self.app_config = app_config
|
||||
self.db = database
|
||||
self.file_storage = file_storage
|
||||
self.llm_provider = llm_provider
|
||||
self.agent_manager = AgentManager(app_data_dir=app_config.app_data_dir)
|
||||
self._task_budgets = {}
|
||||
self.agent_manager = AgentManager(file_storage)
|
||||
self._task_budgets = defaultdict(ModelProviderBudget)
|
||||
|
||||
async def start(self, port: int = 8000, router: APIRouter = base_router):
|
||||
"""Start the agent server."""
|
||||
logger.debug("Starting the agent server...")
|
||||
if not is_port_free(port):
|
||||
logger.error(f"Port {port} is already in use.")
|
||||
logger.info(
|
||||
"You can specify a port by either setting the AP_SERVER_PORT "
|
||||
"environment variable or defining AP_SERVER_PORT in the .env file."
|
||||
)
|
||||
return
|
||||
|
||||
config = HypercornConfig()
|
||||
config.bind = [f"localhost:{port}"]
|
||||
app = FastAPI(
|
||||
@@ -73,11 +79,14 @@ class AgentProtocolServer:
|
||||
version="v0.4",
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
origins = [
|
||||
"*",
|
||||
# Add any other origins you want to whitelist
|
||||
# Configure CORS middleware
|
||||
default_origins = [f"http://localhost:{port}"] # Default only local access
|
||||
configured_origins = [
|
||||
origin
|
||||
for origin in os.getenv("AP_SERVER_CORS_ALLOWED_ORIGINS", "").split(",")
|
||||
if origin # Empty list if not configured
|
||||
]
|
||||
origins = configured_origins or default_origins
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -121,22 +130,22 @@ class AgentProtocolServer:
|
||||
"""
|
||||
Create a task for the agent.
|
||||
"""
|
||||
if user_id := (task_request.additional_input or {}).get("user_id"):
|
||||
set_user({"id": user_id})
|
||||
|
||||
task = await self.db.create_task(
|
||||
input=task_request.input,
|
||||
additional_input=task_request.additional_input,
|
||||
)
|
||||
logger.debug(f"Creating agent for task: '{task.input}'")
|
||||
task_agent = await generate_agent_for_task(
|
||||
agent_id=task_agent_id(task.task_id),
|
||||
task=task.input,
|
||||
app_config=self.app_config,
|
||||
file_storage=self.file_storage,
|
||||
llm_provider=self._get_task_llm_provider(task),
|
||||
)
|
||||
|
||||
# Assign an ID and a folder to the Agent and persist it
|
||||
agent_id = task_agent.state.agent_id = task_agent_id(task.task_id)
|
||||
logger.debug(f"New agent ID: {agent_id}")
|
||||
task_agent.attach_fs(self.app_config.app_data_dir / "agents" / agent_id)
|
||||
task_agent.state.save_to_json_file(task_agent.file_manager.state_file_path)
|
||||
await task_agent.file_manager.save_state()
|
||||
|
||||
return task
|
||||
|
||||
@@ -175,17 +184,21 @@ class AgentProtocolServer:
|
||||
# Restore Agent instance
|
||||
task = await self.get_task(task_id)
|
||||
agent = configure_agent_with_state(
|
||||
state=self.agent_manager.retrieve_state(task_agent_id(task_id)),
|
||||
state=self.agent_manager.load_agent_state(task_agent_id(task_id)),
|
||||
app_config=self.app_config,
|
||||
file_storage=self.file_storage,
|
||||
llm_provider=self._get_task_llm_provider(task),
|
||||
)
|
||||
|
||||
if user_id := (task.additional_input or {}).get("user_id"):
|
||||
set_user({"id": user_id})
|
||||
|
||||
# According to the Agent Protocol spec, the first execute_step request contains
|
||||
# the same task input as the parent create_task request.
|
||||
# To prevent this from interfering with the agent's process, we ignore the input
|
||||
# of this first step request, and just generate the first step proposal.
|
||||
is_init_step = not bool(agent.event_history)
|
||||
execute_command, execute_command_args, execute_result = None, None, None
|
||||
last_proposal, tool_result = None, None
|
||||
execute_approved = False
|
||||
|
||||
# HACK: only for compatibility with AGBenchmark
|
||||
@@ -199,13 +212,11 @@ class AgentProtocolServer:
|
||||
and agent.event_history.current_episode
|
||||
and not agent.event_history.current_episode.result
|
||||
):
|
||||
execute_command = agent.event_history.current_episode.action.name
|
||||
execute_command_args = agent.event_history.current_episode.action.args
|
||||
last_proposal = agent.event_history.current_episode.action
|
||||
execute_approved = not user_input
|
||||
|
||||
logger.debug(
|
||||
f"Agent proposed command"
|
||||
f" {execute_command}({fmt_kwargs(execute_command_args)})."
|
||||
f"Agent proposed command {last_proposal.use_tool}."
|
||||
f" User input/feedback: {repr(user_input)}"
|
||||
)
|
||||
|
||||
@@ -213,58 +224,62 @@ class AgentProtocolServer:
|
||||
step = await self.db.create_step(
|
||||
task_id=task_id,
|
||||
input=step_request,
|
||||
is_last=execute_command == finish.__name__ and execute_approved,
|
||||
is_last=(
|
||||
last_proposal is not None
|
||||
and last_proposal.use_tool.name == DEFAULT_FINISH_COMMAND
|
||||
and execute_approved
|
||||
),
|
||||
)
|
||||
agent.llm_provider = self._get_task_llm_provider(task, step.step_id)
|
||||
|
||||
# Execute previously proposed action
|
||||
if execute_command:
|
||||
assert execute_command_args is not None
|
||||
agent.workspace.on_write_file = lambda path: self._on_agent_write_file(
|
||||
task=task, step=step, relative_path=path
|
||||
if last_proposal:
|
||||
agent.file_manager.workspace.on_write_file = (
|
||||
lambda path: self._on_agent_write_file(
|
||||
task=task, step=step, relative_path=path
|
||||
)
|
||||
)
|
||||
|
||||
if step.is_last and execute_command == finish.__name__:
|
||||
assert execute_command_args
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
step_id=step.step_id,
|
||||
output=execute_command_args["reason"],
|
||||
)
|
||||
logger.info(
|
||||
f"Total LLM cost for task {task_id}: "
|
||||
f"${round(agent.llm_provider.get_incurred_cost(), 2)}"
|
||||
)
|
||||
return step
|
||||
|
||||
if execute_command == ask_user.__name__: # HACK
|
||||
execute_result = ActionSuccessResult(outputs=user_input)
|
||||
agent.event_history.register_result(execute_result)
|
||||
elif not execute_command:
|
||||
execute_result = None
|
||||
if last_proposal.use_tool.name == DEFAULT_ASK_COMMAND:
|
||||
tool_result = ActionSuccessResult(outputs=user_input)
|
||||
agent.event_history.register_result(tool_result)
|
||||
elif execute_approved:
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
step_id=step.step_id,
|
||||
status="running",
|
||||
)
|
||||
# Execute previously proposed action
|
||||
execute_result = await agent.execute(
|
||||
command_name=execute_command,
|
||||
command_args=execute_command_args,
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute previously proposed action
|
||||
tool_result = await agent.execute(last_proposal)
|
||||
except AgentFinished:
|
||||
additional_output = {}
|
||||
task_total_cost = agent.llm_provider.get_incurred_cost()
|
||||
if task_total_cost > 0:
|
||||
additional_output["task_total_cost"] = task_total_cost
|
||||
logger.info(
|
||||
f"Total LLM cost for task {task_id}: "
|
||||
f"${round(task_total_cost, 2)}"
|
||||
)
|
||||
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
step_id=step.step_id,
|
||||
output=last_proposal.use_tool.arguments["reason"],
|
||||
additional_output=additional_output,
|
||||
)
|
||||
await agent.file_manager.save_state()
|
||||
return step
|
||||
else:
|
||||
assert user_input
|
||||
execute_result = await agent.execute(
|
||||
command_name="human_feedback", # HACK
|
||||
command_args={},
|
||||
user_input=user_input,
|
||||
)
|
||||
tool_result = await agent.do_not_execute(last_proposal, user_input)
|
||||
|
||||
# Propose next action
|
||||
try:
|
||||
next_command, next_command_args, raw_output = await agent.propose_action()
|
||||
logger.debug(f"AI output: {raw_output}")
|
||||
assistant_response = await agent.propose_action()
|
||||
next_tool_to_use = assistant_response.use_tool
|
||||
logger.debug(f"AI output: {assistant_response.thoughts}")
|
||||
except Exception as e:
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
@@ -277,42 +292,54 @@ class AgentProtocolServer:
|
||||
# Format step output
|
||||
output = (
|
||||
(
|
||||
f"`{execute_command}({fmt_kwargs(execute_command_args)})` returned:"
|
||||
+ ("\n\n" if "\n" in str(execute_result) else " ")
|
||||
+ f"{execute_result}\n\n"
|
||||
f"`{last_proposal.use_tool}` returned:"
|
||||
+ ("\n\n" if "\n" in str(tool_result) else " ")
|
||||
+ f"{tool_result}\n\n"
|
||||
)
|
||||
if execute_command_args and execute_command != ask_user.__name__
|
||||
if last_proposal and last_proposal.use_tool.name != DEFAULT_ASK_COMMAND
|
||||
else ""
|
||||
)
|
||||
output += f"{raw_output['thoughts']['speak']}\n\n"
|
||||
output += f"{assistant_response.thoughts.speak}\n\n"
|
||||
output += (
|
||||
f"Next Command: {next_command}({fmt_kwargs(next_command_args)})"
|
||||
if next_command != ask_user.__name__
|
||||
else next_command_args["question"]
|
||||
f"Next Command: {next_tool_to_use}"
|
||||
if next_tool_to_use.name != DEFAULT_ASK_COMMAND
|
||||
else next_tool_to_use.arguments["question"]
|
||||
)
|
||||
|
||||
additional_output = {
|
||||
**(
|
||||
{
|
||||
"last_action": {
|
||||
"name": execute_command,
|
||||
"args": execute_command_args,
|
||||
"name": last_proposal.use_tool.name,
|
||||
"args": last_proposal.use_tool.arguments,
|
||||
"result": (
|
||||
orjson.loads(execute_result.json())
|
||||
if not isinstance(execute_result, ActionErrorResult)
|
||||
else {
|
||||
"error": str(execute_result.error),
|
||||
"reason": execute_result.reason,
|
||||
}
|
||||
""
|
||||
if tool_result is None
|
||||
else (
|
||||
orjson.loads(tool_result.json())
|
||||
if not isinstance(tool_result, ActionErrorResult)
|
||||
else {
|
||||
"error": str(tool_result.error),
|
||||
"reason": tool_result.reason,
|
||||
}
|
||||
)
|
||||
),
|
||||
},
|
||||
}
|
||||
if not is_init_step
|
||||
if last_proposal and tool_result
|
||||
else {}
|
||||
),
|
||||
**raw_output,
|
||||
**assistant_response.dict(),
|
||||
}
|
||||
|
||||
task_cumulative_cost = agent.llm_provider.get_incurred_cost()
|
||||
if task_cumulative_cost > 0:
|
||||
additional_output["task_cumulative_cost"] = task_cumulative_cost
|
||||
logger.debug(
|
||||
f"Running total LLM cost for task {task_id}: "
|
||||
f"${round(task_cumulative_cost, 3)}"
|
||||
)
|
||||
|
||||
step = await self.db.update_step(
|
||||
task_id=task_id,
|
||||
step_id=step.step_id,
|
||||
@@ -321,11 +348,7 @@ class AgentProtocolServer:
|
||||
additional_output=additional_output,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Running total LLM cost for task {task_id}: "
|
||||
f"${round(agent.llm_provider.get_incurred_cost(), 3)}"
|
||||
)
|
||||
agent.state.save_to_json_file(agent.file_manager.state_file_path)
|
||||
await agent.file_manager.save_state()
|
||||
return step
|
||||
|
||||
async def _on_agent_write_file(
|
||||
@@ -384,7 +407,7 @@ class AgentProtocolServer:
|
||||
else:
|
||||
file_path = os.path.join(relative_path, file_name)
|
||||
|
||||
workspace = self._get_task_agent_file_workspace(task_id, self.agent_manager)
|
||||
workspace = self._get_task_agent_file_workspace(task_id)
|
||||
await workspace.write_file(file_path, data)
|
||||
|
||||
artifact = await self.db.create_artifact(
|
||||
@@ -400,12 +423,12 @@ class AgentProtocolServer:
|
||||
Download a task artifact by ID.
|
||||
"""
|
||||
try:
|
||||
workspace = self._get_task_agent_file_workspace(task_id)
|
||||
artifact = await self.db.get_artifact(artifact_id)
|
||||
if artifact.file_name not in artifact.relative_path:
|
||||
file_path = os.path.join(artifact.relative_path, artifact.file_name)
|
||||
else:
|
||||
file_path = artifact.relative_path
|
||||
workspace = self._get_task_agent_file_workspace(task_id, self.agent_manager)
|
||||
retrieved_artifact = workspace.read_file(file_path, binary=True)
|
||||
except NotFoundError:
|
||||
raise
|
||||
@@ -420,28 +443,9 @@ class AgentProtocolServer:
|
||||
},
|
||||
)
|
||||
|
||||
def _get_task_agent_file_workspace(
|
||||
self,
|
||||
task_id: str | int,
|
||||
agent_manager: AgentManager,
|
||||
) -> FileWorkspace:
|
||||
use_local_ws = (
|
||||
self.app_config.workspace_backend == FileWorkspaceBackendName.LOCAL
|
||||
)
|
||||
def _get_task_agent_file_workspace(self, task_id: str | int) -> FileStorage:
|
||||
agent_id = task_agent_id(task_id)
|
||||
workspace = get_workspace(
|
||||
backend=self.app_config.workspace_backend,
|
||||
id=agent_id if not use_local_ws else "",
|
||||
root_path=agent_manager.get_agent_dir(
|
||||
agent_id=agent_id,
|
||||
must_exist=True,
|
||||
)
|
||||
/ "workspace"
|
||||
if use_local_ws
|
||||
else None,
|
||||
)
|
||||
workspace.initialize()
|
||||
return workspace
|
||||
return self.file_storage.clone_with_subroot(f"agents/{agent_id}/workspace")
|
||||
|
||||
def _get_task_llm_provider(
|
||||
self, task: Task, step_id: str = ""
|
||||
@@ -449,9 +453,7 @@ class AgentProtocolServer:
|
||||
"""
|
||||
Configures the LLM provider with headers to link outgoing requests to the task.
|
||||
"""
|
||||
task_llm_budget = self._task_budgets.get(
|
||||
task.task_id, self.llm_provider.default_settings.budget.copy(deep=True)
|
||||
)
|
||||
task_llm_budget = self._task_budgets[task.task_id]
|
||||
|
||||
task_llm_provider_config = self.llm_provider._configuration.copy(deep=True)
|
||||
_extra_request_headers = task_llm_provider_config.extra_request_headers
|
||||
@@ -461,20 +463,18 @@ class AgentProtocolServer:
|
||||
if task.additional_input and (user_id := task.additional_input.get("user_id")):
|
||||
_extra_request_headers["AutoGPT-UserID"] = user_id
|
||||
|
||||
task_llm_provider = None
|
||||
if isinstance(self.llm_provider, OpenAIProvider):
|
||||
settings = self.llm_provider._settings.copy()
|
||||
settings.budget = task_llm_budget
|
||||
settings.configuration = task_llm_provider_config # type: ignore
|
||||
task_llm_provider = OpenAIProvider(
|
||||
settings=settings,
|
||||
logger=logger.getChild(f"Task-{task.task_id}_OpenAIProvider"),
|
||||
)
|
||||
settings = self.llm_provider._settings.copy()
|
||||
settings.budget = task_llm_budget
|
||||
settings.configuration = task_llm_provider_config
|
||||
task_llm_provider = self.llm_provider.__class__(
|
||||
settings=settings,
|
||||
logger=logger.getChild(
|
||||
f"Task-{task.task_id}_{self.llm_provider.__class__.__name__}"
|
||||
),
|
||||
)
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget # type: ignore
|
||||
|
||||
if task_llm_provider and task_llm_provider._budget:
|
||||
self._task_budgets[task.task_id] = task_llm_provider._budget
|
||||
|
||||
return task_llm_provider or self.llm_provider
|
||||
return task_llm_provider
|
||||
|
||||
|
||||
def task_agent_id(task_id: str | int) -> str:
|
||||
|
||||
@@ -7,10 +7,14 @@ import click
|
||||
|
||||
from autogpt.logs.config import LogFormatName
|
||||
|
||||
from .telemetry import setup_telemetry
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context):
|
||||
setup_telemetry()
|
||||
|
||||
# Invoke `run` by default
|
||||
if ctx.invoked_subcommand is None:
|
||||
ctx.invoke(run)
|
||||
|
||||
@@ -3,37 +3,28 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
from typing import Literal, Optional
|
||||
|
||||
import click
|
||||
from colorama import Back, Fore, Style
|
||||
|
||||
from autogpt import utils
|
||||
from autogpt.config import Config
|
||||
from autogpt.config.config import GPT_3_MODEL, GPT_4_MODEL
|
||||
from autogpt.llm.api_manager import ApiManager
|
||||
from autogpt.logs.config import LogFormatName
|
||||
from autogpt.core.resource.model_providers import ModelName, MultiProvider
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.memory.vector import get_supported_memory_backends
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.core.resource.model_providers.openai import OpenAICredentials
|
||||
from autogpt.utils import utils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_overrides_to_config(
|
||||
async def apply_overrides_to_config(
|
||||
config: Config,
|
||||
continuous: bool = False,
|
||||
continuous_limit: Optional[int] = None,
|
||||
ai_settings_file: Optional[Path] = None,
|
||||
prompt_settings_file: Optional[Path] = None,
|
||||
skip_reprompt: bool = False,
|
||||
speak: bool = False,
|
||||
debug: bool = False,
|
||||
log_level: Optional[str] = None,
|
||||
log_format: Optional[str] = None,
|
||||
log_file_format: Optional[str] = None,
|
||||
gpt3only: bool = False,
|
||||
gpt4only: bool = False,
|
||||
memory_type: Optional[str] = None,
|
||||
@@ -63,19 +54,6 @@ def apply_overrides_to_config(
|
||||
skips_news (bool): Whether to suppress the output of latest news on startup.
|
||||
"""
|
||||
config.continuous_mode = False
|
||||
config.tts_config.speak_mode = False
|
||||
|
||||
# Set log level
|
||||
if debug:
|
||||
config.logging.level = logging.DEBUG
|
||||
elif log_level and type(_level := logging.getLevelName(log_level.upper())) is int:
|
||||
config.logging.level = _level
|
||||
|
||||
# Set log format
|
||||
if log_format and log_format in LogFormatName._value2member_map_:
|
||||
config.logging.log_format = LogFormatName(log_format)
|
||||
if log_file_format and log_file_format in LogFormatName._value2member_map_:
|
||||
config.logging.log_file_format = LogFormatName(log_file_format)
|
||||
|
||||
if continuous:
|
||||
logger.warning(
|
||||
@@ -92,9 +70,6 @@ def apply_overrides_to_config(
|
||||
if continuous_limit and not continuous:
|
||||
raise click.UsageError("--continuous-limit can only be used with --continuous")
|
||||
|
||||
if speak:
|
||||
config.tts_config.speak_mode = True
|
||||
|
||||
# Set the default LLM models
|
||||
if gpt3only:
|
||||
# --gpt3only should always use gpt-3.5-turbo, despite user's FAST_LLM config
|
||||
@@ -102,23 +77,14 @@ def apply_overrides_to_config(
|
||||
config.smart_llm = GPT_3_MODEL
|
||||
elif (
|
||||
gpt4only
|
||||
and check_model(
|
||||
GPT_4_MODEL,
|
||||
model_type="smart_llm",
|
||||
api_credentials=config.openai_credentials,
|
||||
)
|
||||
== GPT_4_MODEL
|
||||
and (await check_model(GPT_4_MODEL, model_type="smart_llm")) == GPT_4_MODEL
|
||||
):
|
||||
# --gpt4only should always use gpt-4, despite user's SMART_LLM config
|
||||
config.fast_llm = GPT_4_MODEL
|
||||
config.smart_llm = GPT_4_MODEL
|
||||
else:
|
||||
config.fast_llm = check_model(
|
||||
config.fast_llm, "fast_llm", api_credentials=config.openai_credentials
|
||||
)
|
||||
config.smart_llm = check_model(
|
||||
config.smart_llm, "smart_llm", api_credentials=config.openai_credentials
|
||||
)
|
||||
config.fast_llm = await check_model(config.fast_llm, "fast_llm")
|
||||
config.smart_llm = await check_model(config.smart_llm, "smart_llm")
|
||||
|
||||
if memory_type:
|
||||
supported_memory = get_supported_memory_backends()
|
||||
@@ -183,19 +149,17 @@ def apply_overrides_to_config(
|
||||
config.skip_news = True
|
||||
|
||||
|
||||
def check_model(
|
||||
model_name: str,
|
||||
model_type: Literal["smart_llm", "fast_llm"],
|
||||
api_credentials: OpenAICredentials,
|
||||
) -> str:
|
||||
async def check_model(
|
||||
model_name: ModelName, model_type: Literal["smart_llm", "fast_llm"]
|
||||
) -> ModelName:
|
||||
"""Check if model is available for use. If not, return gpt-3.5-turbo."""
|
||||
api_manager = ApiManager()
|
||||
models = api_manager.get_models(api_credentials)
|
||||
multi_provider = MultiProvider()
|
||||
models = await multi_provider.get_available_models()
|
||||
|
||||
if any(model_name == m.id for m in models):
|
||||
if any(model_name == m.name for m in models):
|
||||
return model_name
|
||||
|
||||
logger.warning(
|
||||
f"You don't have access to {model_name}. Setting {model_type} to gpt-3.5-turbo."
|
||||
f"You don't have access to {model_name}. Setting {model_type} to {GPT_3_MODEL}."
|
||||
)
|
||||
return "gpt-3.5-turbo"
|
||||
return GPT_3_MODEL
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
The application entry point. Can be invoked by a CLI or any other front end application.
|
||||
"""
|
||||
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
@@ -17,12 +18,16 @@ from forge.sdk.db import AgentDB
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
|
||||
from autogpt.agent_factory.configurators import configure_agent_with_state, create_agent
|
||||
from autogpt.agent_factory.profile_generator import generate_agent_profile_for_task
|
||||
from autogpt.agent_manager import AgentManager
|
||||
from autogpt.agents import AgentThoughts, CommandArgs, CommandName
|
||||
from autogpt.agents.utils.exceptions import AgentTerminated, InvalidAgentResponseError
|
||||
from autogpt.agents.prompt_strategies.one_shot import AssistantThoughts
|
||||
from autogpt.commands.execute_code import (
|
||||
is_docker_available,
|
||||
we_are_running_in_a_docker_container,
|
||||
)
|
||||
from autogpt.config import (
|
||||
AIDirectives,
|
||||
AIProfile,
|
||||
@@ -30,12 +35,15 @@ from autogpt.config import (
|
||||
ConfigBuilder,
|
||||
assert_config_has_openai_api_key,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.openai import OpenAIProvider
|
||||
from autogpt.core.resource.model_providers import MultiProvider
|
||||
from autogpt.core.runner.client_lib.utils import coroutine
|
||||
from autogpt.logs.config import configure_chat_plugins, configure_logging
|
||||
from autogpt.file_storage import FileStorageBackendName, get_storage
|
||||
from autogpt.logs.config import configure_logging
|
||||
from autogpt.logs.helpers import print_attribute, speak
|
||||
from autogpt.plugins import scan_plugins
|
||||
from scripts.install_plugin_deps import install_plugin_dependencies
|
||||
from autogpt.models.action_history import ActionInterruptedByHuman
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.utils.exceptions import AgentTerminated, InvalidAgentResponseError
|
||||
from autogpt.utils.utils import DEFAULT_FINISH_COMMAND
|
||||
|
||||
from .configurator import apply_overrides_to_config
|
||||
from .setup import apply_overrides_to_ai_settings, interactively_revise_ai_settings
|
||||
@@ -76,23 +84,38 @@ async def run_auto_gpt(
|
||||
best_practices: Optional[list[str]] = None,
|
||||
override_directives: bool = False,
|
||||
):
|
||||
# Set up configuration
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
# Storage
|
||||
local = config.file_storage_backend == FileStorageBackendName.LOCAL
|
||||
restrict_to_root = not local or config.restrict_to_workspace
|
||||
file_storage = get_storage(
|
||||
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
|
||||
)
|
||||
file_storage.initialize()
|
||||
|
||||
# Set up logging module
|
||||
if speak:
|
||||
config.tts_config.speak_mode = True
|
||||
configure_logging(
|
||||
debug=debug,
|
||||
level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
config=config.logging,
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
# TODO: fill in llm values here
|
||||
assert_config_has_openai_api_key(config)
|
||||
|
||||
apply_overrides_to_config(
|
||||
await apply_overrides_to_config(
|
||||
config=config,
|
||||
continuous=continuous,
|
||||
continuous_limit=continuous_limit,
|
||||
ai_settings_file=ai_settings,
|
||||
prompt_settings_file=prompt_settings,
|
||||
skip_reprompt=skip_reprompt,
|
||||
speak=speak,
|
||||
debug=debug,
|
||||
log_level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
gpt3only=gpt3only,
|
||||
gpt4only=gpt4only,
|
||||
browser_name=browser_name,
|
||||
@@ -100,13 +123,7 @@ async def run_auto_gpt(
|
||||
skip_news=skip_news,
|
||||
)
|
||||
|
||||
# Set up logging module
|
||||
configure_logging(
|
||||
**config.logging.dict(),
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -140,15 +157,17 @@ async def run_auto_gpt(
|
||||
print_attribute("Using Prompt Settings File", prompt_settings)
|
||||
if config.allow_downloads:
|
||||
print_attribute("Native Downloading", "ENABLED")
|
||||
|
||||
if install_plugin_deps:
|
||||
install_plugin_dependencies()
|
||||
|
||||
config.plugins = scan_plugins(config)
|
||||
configure_chat_plugins(config)
|
||||
if we_are_running_in_a_docker_container() or is_docker_available():
|
||||
print_attribute("Code Execution", "ENABLED")
|
||||
else:
|
||||
print_attribute(
|
||||
"Code Execution",
|
||||
"DISABLED (Docker unavailable)",
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
|
||||
# Let user choose an existing agent to run
|
||||
agent_manager = AgentManager(config.app_data_dir)
|
||||
agent_manager = AgentManager(file_storage)
|
||||
existing_agents = agent_manager.list_agents()
|
||||
load_existing_agent = ""
|
||||
if existing_agents:
|
||||
@@ -156,15 +175,23 @@ async def run_auto_gpt(
|
||||
"Existing agents\n---------------\n"
|
||||
+ "\n".join(f"{i} - {id}" for i, id in enumerate(existing_agents, 1))
|
||||
)
|
||||
load_existing_agent = await clean_input(
|
||||
load_existing_agent = clean_input(
|
||||
config,
|
||||
"Enter the number or name of the agent to run,"
|
||||
" or hit enter to create a new one:",
|
||||
)
|
||||
if re.match(r"^\d+$", load_existing_agent):
|
||||
if re.match(r"^\d+$", load_existing_agent.strip()) and 0 < int(
|
||||
load_existing_agent
|
||||
) <= len(existing_agents):
|
||||
load_existing_agent = existing_agents[int(load_existing_agent) - 1]
|
||||
elif load_existing_agent and load_existing_agent not in existing_agents:
|
||||
raise ValueError(f"Unknown agent '{load_existing_agent}'")
|
||||
|
||||
if load_existing_agent != "" and load_existing_agent not in existing_agents:
|
||||
logger.info(
|
||||
f"Unknown agent '{load_existing_agent}', "
|
||||
f"creating a new one instead.",
|
||||
extra={"color": Fore.YELLOW},
|
||||
)
|
||||
load_existing_agent = ""
|
||||
|
||||
# Either load existing or set up new agent state
|
||||
agent = None
|
||||
@@ -174,21 +201,20 @@ async def run_auto_gpt(
|
||||
# Resume an Existing Agent #
|
||||
############################
|
||||
if load_existing_agent:
|
||||
agent_state = agent_manager.retrieve_state(load_existing_agent)
|
||||
agent_state = None
|
||||
while True:
|
||||
answer = await clean_input(config, "Resume? [Y/n]")
|
||||
if answer.lower() == "y":
|
||||
answer = clean_input(config, "Resume? [Y/n]")
|
||||
if answer == "" or answer.lower() == "y":
|
||||
agent_state = agent_manager.load_agent_state(load_existing_agent)
|
||||
break
|
||||
elif answer.lower() == "n":
|
||||
agent_state = None
|
||||
break
|
||||
else:
|
||||
print("Please respond with 'y' or 'n'")
|
||||
|
||||
if agent_state:
|
||||
agent = configure_agent_with_state(
|
||||
state=agent_state,
|
||||
app_config=config,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
apply_overrides_to_ai_settings(
|
||||
@@ -202,6 +228,21 @@ async def run_auto_gpt(
|
||||
replace_directives=override_directives,
|
||||
)
|
||||
|
||||
if (
|
||||
(current_episode := agent.event_history.current_episode)
|
||||
and current_episode.action.use_tool.name == DEFAULT_FINISH_COMMAND
|
||||
and not current_episode.result
|
||||
):
|
||||
# Agent was resumed after `finish` -> rewrite result of `finish` action
|
||||
finish_reason = current_episode.action.use_tool.arguments["reason"]
|
||||
print(f"Agent previously self-terminated; reason: '{finish_reason}'")
|
||||
new_assignment = clean_input(
|
||||
config, "Please give a follow-up question or assignment:"
|
||||
)
|
||||
agent.event_history.register_result(
|
||||
ActionInterruptedByHuman(feedback=new_assignment)
|
||||
)
|
||||
|
||||
# If any of these are specified as arguments,
|
||||
# assume the user doesn't want to revise them
|
||||
if not any(
|
||||
@@ -225,11 +266,14 @@ async def run_auto_gpt(
|
||||
# Set up a new Agent #
|
||||
######################
|
||||
if not agent:
|
||||
task = await clean_input(
|
||||
config,
|
||||
"Enter the task that you want AutoGPT to execute,"
|
||||
" with as much detail as possible:",
|
||||
)
|
||||
task = ""
|
||||
while task.strip() == "":
|
||||
task = clean_input(
|
||||
config,
|
||||
"Enter the task that you want AutoGPT to execute,"
|
||||
" with as much detail as possible:",
|
||||
)
|
||||
|
||||
base_ai_directives = AIDirectives.from_file(config.prompt_settings_file)
|
||||
|
||||
ai_profile, task_oriented_ai_directives = await generate_agent_profile_for_task(
|
||||
@@ -269,19 +313,22 @@ async def run_auto_gpt(
|
||||
logger.info("AI config overrides specified through CLI; skipping revision")
|
||||
|
||||
agent = create_agent(
|
||||
agent_id=agent_manager.generate_id(ai_profile.ai_name),
|
||||
task=task,
|
||||
ai_profile=ai_profile,
|
||||
directives=ai_directives,
|
||||
app_config=config,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
agent.attach_fs(agent_manager.get_agent_dir(agent.state.agent_id))
|
||||
|
||||
if not agent.config.allow_fs_access:
|
||||
file_manager = agent.file_manager
|
||||
|
||||
if file_manager and not agent.config.allow_fs_access:
|
||||
logger.info(
|
||||
f"{Fore.YELLOW}"
|
||||
"NOTE: All files/directories created by this agent can be found "
|
||||
f"inside its workspace at:{Fore.RESET} {agent.workspace.root}",
|
||||
f"inside its workspace at:{Fore.RESET} {file_manager.workspace.root}",
|
||||
extra={"preserve_color": True},
|
||||
)
|
||||
|
||||
@@ -295,23 +342,15 @@ async def run_auto_gpt(
|
||||
logger.info(f"Saving state of {agent_id}...")
|
||||
|
||||
# Allow user to Save As other ID
|
||||
save_as_id = (
|
||||
await clean_input(
|
||||
config,
|
||||
f"Press enter to save as '{agent_id}',"
|
||||
" or enter a different ID to save to:",
|
||||
)
|
||||
or agent_id
|
||||
save_as_id = clean_input(
|
||||
config,
|
||||
f"Press enter to save as '{agent_id}',"
|
||||
" or enter a different ID to save to:",
|
||||
)
|
||||
# TODO: allow many-to-one relations of agents and workspaces
|
||||
await agent.file_manager.save_state(
|
||||
save_as_id.strip() if not save_as_id.isspace() else None
|
||||
)
|
||||
if save_as_id and save_as_id != agent_id:
|
||||
agent.set_id(
|
||||
new_id=save_as_id,
|
||||
new_agent_dir=agent_manager.get_agent_dir(save_as_id),
|
||||
)
|
||||
# TODO: clone workspace if user wants that
|
||||
# TODO: ... OR allow many-to-one relations of agents and workspaces
|
||||
|
||||
agent.state.save_to_json_file(agent.file_manager.state_file_path)
|
||||
|
||||
|
||||
@coroutine
|
||||
@@ -330,35 +369,37 @@ async def run_auto_gpt_server(
|
||||
from .agent_protocol_server import AgentProtocolServer
|
||||
|
||||
config = ConfigBuilder.build_config_from_env()
|
||||
# Storage
|
||||
local = config.file_storage_backend == FileStorageBackendName.LOCAL
|
||||
restrict_to_root = not local or config.restrict_to_workspace
|
||||
file_storage = get_storage(
|
||||
config.file_storage_backend, root_path="data", restrict_to_root=restrict_to_root
|
||||
)
|
||||
file_storage.initialize()
|
||||
|
||||
# Set up logging module
|
||||
configure_logging(
|
||||
debug=debug,
|
||||
level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
config=config.logging,
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
# TODO: fill in llm values here
|
||||
assert_config_has_openai_api_key(config)
|
||||
|
||||
apply_overrides_to_config(
|
||||
await apply_overrides_to_config(
|
||||
config=config,
|
||||
prompt_settings_file=prompt_settings,
|
||||
debug=debug,
|
||||
log_level=log_level,
|
||||
log_format=log_format,
|
||||
log_file_format=log_file_format,
|
||||
gpt3only=gpt3only,
|
||||
gpt4only=gpt4only,
|
||||
browser_name=browser_name,
|
||||
allow_downloads=allow_downloads,
|
||||
)
|
||||
|
||||
# Set up logging module
|
||||
configure_logging(
|
||||
**config.logging.dict(),
|
||||
tts_config=config.tts_config,
|
||||
)
|
||||
|
||||
llm_provider = _configure_openai_provider(config)
|
||||
|
||||
if install_plugin_deps:
|
||||
install_plugin_dependencies()
|
||||
|
||||
config.plugins = scan_plugins(config)
|
||||
llm_provider = _configure_llm_provider(config)
|
||||
|
||||
# Set up & start server
|
||||
database = AgentDB(
|
||||
@@ -367,7 +408,10 @@ async def run_auto_gpt_server(
|
||||
)
|
||||
port: int = int(os.getenv("AP_SERVER_PORT", default=8000))
|
||||
server = AgentProtocolServer(
|
||||
app_config=config, database=database, llm_provider=llm_provider
|
||||
app_config=config,
|
||||
database=database,
|
||||
file_storage=file_storage,
|
||||
llm_provider=llm_provider,
|
||||
)
|
||||
await server.start(port=port)
|
||||
|
||||
@@ -377,24 +421,12 @@ async def run_auto_gpt_server(
|
||||
)
|
||||
|
||||
|
||||
def _configure_openai_provider(config: Config) -> OpenAIProvider:
|
||||
"""Create a configured OpenAIProvider object.
|
||||
|
||||
Args:
|
||||
config: The program's configuration.
|
||||
|
||||
Returns:
|
||||
A configured OpenAIProvider object.
|
||||
"""
|
||||
if config.openai_credentials is None:
|
||||
raise RuntimeError("OpenAI key is not configured")
|
||||
|
||||
openai_settings = OpenAIProvider.default_settings.copy(deep=True)
|
||||
openai_settings.credentials = config.openai_credentials
|
||||
return OpenAIProvider(
|
||||
settings=openai_settings,
|
||||
logger=logging.getLogger("OpenAIProvider"),
|
||||
)
|
||||
def _configure_llm_provider(config: Config) -> MultiProvider:
|
||||
multi_provider = MultiProvider()
|
||||
for model in [config.smart_llm, config.fast_llm]:
|
||||
# Ensure model providers for configured LLMs are available
|
||||
multi_provider.get_model_provider(model)
|
||||
return multi_provider
|
||||
|
||||
|
||||
def _get_cycle_budget(continuous_mode: bool, continuous_limit: int) -> int | float:
|
||||
@@ -488,11 +520,7 @@ async def run_interaction_loop(
|
||||
# Have the agent determine the next action to take.
|
||||
with spinner:
|
||||
try:
|
||||
(
|
||||
command_name,
|
||||
command_args,
|
||||
assistant_reply_dict,
|
||||
) = await agent.propose_action()
|
||||
action_proposal = await agent.propose_action()
|
||||
except InvalidAgentResponseError as e:
|
||||
logger.warning(f"The agent's thoughts could not be parsed: {e}")
|
||||
consecutive_failures += 1
|
||||
@@ -515,9 +543,7 @@ async def run_interaction_loop(
|
||||
# Print the assistant's thoughts and the next command to the user.
|
||||
update_user(
|
||||
ai_profile,
|
||||
command_name,
|
||||
command_args,
|
||||
assistant_reply_dict,
|
||||
action_proposal,
|
||||
speak_mode=legacy_config.tts_config.speak_mode,
|
||||
)
|
||||
|
||||
@@ -526,12 +552,12 @@ async def run_interaction_loop(
|
||||
##################
|
||||
handle_stop_signal()
|
||||
if cycles_remaining == 1: # Last cycle
|
||||
user_feedback, user_input, new_cycles_remaining = await get_user_feedback(
|
||||
feedback_type, feedback, new_cycles_remaining = await get_user_feedback(
|
||||
legacy_config,
|
||||
ai_profile,
|
||||
)
|
||||
|
||||
if user_feedback == UserFeedback.AUTHORIZE:
|
||||
if feedback_type == UserFeedback.AUTHORIZE:
|
||||
if new_cycles_remaining is not None:
|
||||
# Case 1: User is altering the cycle budget.
|
||||
if cycle_budget > 1:
|
||||
@@ -555,13 +581,13 @@ async def run_interaction_loop(
|
||||
"-=-=-=-=-=-=-= COMMAND AUTHORISED BY USER -=-=-=-=-=-=-=",
|
||||
extra={"color": Fore.MAGENTA},
|
||||
)
|
||||
elif user_feedback == UserFeedback.EXIT:
|
||||
elif feedback_type == UserFeedback.EXIT:
|
||||
logger.warning("Exiting...")
|
||||
exit()
|
||||
else: # user_feedback == UserFeedback.TEXT
|
||||
command_name = "human_feedback"
|
||||
pass
|
||||
else:
|
||||
user_input = ""
|
||||
feedback = ""
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
print()
|
||||
if cycles_remaining != math.inf:
|
||||
@@ -576,33 +602,31 @@ async def run_interaction_loop(
|
||||
# Decrement the cycle counter first to reduce the likelihood of a SIGINT
|
||||
# happening during command execution, setting the cycles remaining to 1,
|
||||
# and then having the decrement set it to 0, exiting the application.
|
||||
if command_name != "human_feedback":
|
||||
if not feedback:
|
||||
cycles_remaining -= 1
|
||||
|
||||
if not command_name:
|
||||
if not action_proposal.use_tool:
|
||||
continue
|
||||
|
||||
handle_stop_signal()
|
||||
|
||||
if command_name:
|
||||
result = await agent.execute(command_name, command_args, user_input)
|
||||
if not feedback:
|
||||
result = await agent.execute(action_proposal)
|
||||
else:
|
||||
result = await agent.do_not_execute(action_proposal, feedback)
|
||||
|
||||
if result.status == "success":
|
||||
logger.info(
|
||||
result, extra={"title": "SYSTEM:", "title_color": Fore.YELLOW}
|
||||
)
|
||||
elif result.status == "error":
|
||||
logger.warning(
|
||||
f"Command {command_name} returned an error: "
|
||||
f"{result.error or result.reason}"
|
||||
)
|
||||
if result.status == "success":
|
||||
logger.info(result, extra={"title": "SYSTEM:", "title_color": Fore.YELLOW})
|
||||
elif result.status == "error":
|
||||
logger.warning(
|
||||
f"Command {action_proposal.use_tool.name} returned an error: "
|
||||
f"{result.error or result.reason}"
|
||||
)
|
||||
|
||||
|
||||
def update_user(
|
||||
ai_profile: AIProfile,
|
||||
command_name: CommandName,
|
||||
command_args: CommandArgs,
|
||||
assistant_reply_dict: AgentThoughts,
|
||||
action_proposal: "BaseAgentActionProposal",
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
"""Prints the assistant's thoughts and the next command to the user.
|
||||
@@ -618,18 +642,19 @@ def update_user(
|
||||
|
||||
print_assistant_thoughts(
|
||||
ai_name=ai_profile.ai_name,
|
||||
assistant_reply_json_valid=assistant_reply_dict,
|
||||
thoughts=action_proposal.thoughts,
|
||||
speak_mode=speak_mode,
|
||||
)
|
||||
|
||||
if speak_mode:
|
||||
speak(f"I want to execute {command_name}")
|
||||
speak(f"I want to execute {action_proposal.use_tool.name}")
|
||||
|
||||
# First log new-line so user can differentiate sections better in console
|
||||
print()
|
||||
safe_tool_name = remove_ansi_escape(action_proposal.use_tool.name)
|
||||
logger.info(
|
||||
f"COMMAND = {Fore.CYAN}{remove_ansi_escape(command_name)}{Style.RESET_ALL} "
|
||||
f"ARGUMENTS = {Fore.CYAN}{command_args}{Style.RESET_ALL}",
|
||||
f"COMMAND = {Fore.CYAN}{safe_tool_name}{Style.RESET_ALL} "
|
||||
f"ARGUMENTS = {Fore.CYAN}{action_proposal.use_tool.arguments}{Style.RESET_ALL}",
|
||||
extra={
|
||||
"title": "NEXT ACTION:",
|
||||
"title_color": Fore.CYAN,
|
||||
@@ -670,12 +695,7 @@ async def get_user_feedback(
|
||||
|
||||
while user_feedback is None:
|
||||
# Get input from user
|
||||
if config.chat_messages_enabled:
|
||||
console_input = await clean_input(config, "Waiting for your response...")
|
||||
else:
|
||||
console_input = await clean_input(
|
||||
config, Fore.MAGENTA + "Input:" + Style.RESET_ALL
|
||||
)
|
||||
console_input = clean_input(config, Fore.MAGENTA + "Input:" + Style.RESET_ALL)
|
||||
|
||||
# Parse user input
|
||||
if console_input.lower().strip() == config.authorise_key:
|
||||
@@ -703,56 +723,59 @@ async def get_user_feedback(
|
||||
|
||||
def print_assistant_thoughts(
|
||||
ai_name: str,
|
||||
assistant_reply_json_valid: dict,
|
||||
thoughts: str | ModelWithSummary | AssistantThoughts,
|
||||
speak_mode: bool = False,
|
||||
) -> None:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
assistant_thoughts_reasoning = None
|
||||
assistant_thoughts_plan = None
|
||||
assistant_thoughts_speak = None
|
||||
assistant_thoughts_criticism = None
|
||||
|
||||
assistant_thoughts = assistant_reply_json_valid.get("thoughts", {})
|
||||
assistant_thoughts_text = remove_ansi_escape(assistant_thoughts.get("text", ""))
|
||||
if assistant_thoughts:
|
||||
assistant_thoughts_reasoning = remove_ansi_escape(
|
||||
assistant_thoughts.get("reasoning", "")
|
||||
)
|
||||
assistant_thoughts_plan = remove_ansi_escape(assistant_thoughts.get("plan", ""))
|
||||
assistant_thoughts_criticism = remove_ansi_escape(
|
||||
assistant_thoughts.get("self_criticism", "")
|
||||
)
|
||||
assistant_thoughts_speak = remove_ansi_escape(
|
||||
assistant_thoughts.get("speak", "")
|
||||
)
|
||||
print_attribute(
|
||||
f"{ai_name.upper()} THOUGHTS", assistant_thoughts_text, title_color=Fore.YELLOW
|
||||
thoughts_text = remove_ansi_escape(
|
||||
thoughts.text
|
||||
if isinstance(thoughts, AssistantThoughts)
|
||||
else thoughts.summary()
|
||||
if isinstance(thoughts, ModelWithSummary)
|
||||
else thoughts
|
||||
)
|
||||
print_attribute("REASONING", assistant_thoughts_reasoning, title_color=Fore.YELLOW)
|
||||
if assistant_thoughts_plan:
|
||||
print_attribute("PLAN", "", title_color=Fore.YELLOW)
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.info(line.strip(), extra={"title": "- ", "title_color": Fore.GREEN})
|
||||
print_attribute(
|
||||
"CRITICISM", f"{assistant_thoughts_criticism}", title_color=Fore.YELLOW
|
||||
f"{ai_name.upper()} THOUGHTS", thoughts_text, title_color=Fore.YELLOW
|
||||
)
|
||||
|
||||
# Speak the assistant's thoughts
|
||||
if assistant_thoughts_speak:
|
||||
if speak_mode:
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute("SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW)
|
||||
if isinstance(thoughts, AssistantThoughts):
|
||||
print_attribute(
|
||||
"REASONING", remove_ansi_escape(thoughts.reasoning), title_color=Fore.YELLOW
|
||||
)
|
||||
if assistant_thoughts_plan := remove_ansi_escape(
|
||||
"\n".join(f"- {p}" for p in thoughts.plan)
|
||||
):
|
||||
print_attribute("PLAN", "", title_color=Fore.YELLOW)
|
||||
# If it's a list, join it into a string
|
||||
if isinstance(assistant_thoughts_plan, list):
|
||||
assistant_thoughts_plan = "\n".join(assistant_thoughts_plan)
|
||||
elif isinstance(assistant_thoughts_plan, dict):
|
||||
assistant_thoughts_plan = str(assistant_thoughts_plan)
|
||||
|
||||
# Split the input_string using the newline character and dashes
|
||||
lines = assistant_thoughts_plan.split("\n")
|
||||
for line in lines:
|
||||
line = line.lstrip("- ")
|
||||
logger.info(
|
||||
line.strip(), extra={"title": "- ", "title_color": Fore.GREEN}
|
||||
)
|
||||
print_attribute(
|
||||
"CRITICISM",
|
||||
remove_ansi_escape(thoughts.self_criticism),
|
||||
title_color=Fore.YELLOW,
|
||||
)
|
||||
|
||||
# Speak the assistant's thoughts
|
||||
if assistant_thoughts_speak := remove_ansi_escape(thoughts.speak):
|
||||
if speak_mode:
|
||||
speak(assistant_thoughts_speak)
|
||||
else:
|
||||
print_attribute(
|
||||
"SPEAK", assistant_thoughts_speak, title_color=Fore.YELLOW
|
||||
)
|
||||
else:
|
||||
speak(thoughts_text)
|
||||
|
||||
|
||||
def remove_ansi_escape(s: str) -> str:
|
||||
|
||||
@@ -69,44 +69,48 @@ async def interactively_revise_ai_settings(
|
||||
)
|
||||
|
||||
if (
|
||||
await clean_input(app_config, "Continue with these settings? [Y/n]")
|
||||
clean_input(app_config, "Continue with these settings? [Y/n]").lower()
|
||||
or app_config.authorise_key
|
||||
) == app_config.authorise_key:
|
||||
break
|
||||
|
||||
# Ask for revised ai_profile
|
||||
ai_profile.ai_name = (
|
||||
await clean_input(
|
||||
app_config, "Enter AI name (or press enter to keep current):"
|
||||
)
|
||||
clean_input(app_config, "Enter AI name (or press enter to keep current):")
|
||||
or ai_profile.ai_name
|
||||
)
|
||||
ai_profile.ai_role = (
|
||||
await clean_input(
|
||||
clean_input(
|
||||
app_config, "Enter new AI role (or press enter to keep current):"
|
||||
)
|
||||
or ai_profile.ai_role
|
||||
)
|
||||
|
||||
# Revise constraints
|
||||
for i, constraint in enumerate(directives.constraints):
|
||||
i = 0
|
||||
while i < len(directives.constraints):
|
||||
constraint = directives.constraints[i]
|
||||
print_attribute(f"Constraint {i+1}:", f'"{constraint}"')
|
||||
new_constraint = (
|
||||
await clean_input(
|
||||
clean_input(
|
||||
app_config,
|
||||
f"Enter new constraint {i+1}"
|
||||
" (press enter to keep current, or '-' to remove):",
|
||||
)
|
||||
or constraint
|
||||
)
|
||||
|
||||
if new_constraint == "-":
|
||||
directives.constraints.remove(constraint)
|
||||
continue
|
||||
elif new_constraint:
|
||||
directives.constraints[i] = new_constraint
|
||||
|
||||
i += 1
|
||||
|
||||
# Add new constraints
|
||||
while True:
|
||||
new_constraint = await clean_input(
|
||||
new_constraint = clean_input(
|
||||
app_config,
|
||||
"Press enter to finish, or enter a constraint to add:",
|
||||
)
|
||||
@@ -115,10 +119,12 @@ async def interactively_revise_ai_settings(
|
||||
directives.constraints.append(new_constraint)
|
||||
|
||||
# Revise resources
|
||||
for i, resource in enumerate(directives.resources):
|
||||
i = 0
|
||||
while i < len(directives.resources):
|
||||
resource = directives.resources[i]
|
||||
print_attribute(f"Resource {i+1}:", f'"{resource}"')
|
||||
new_resource = (
|
||||
await clean_input(
|
||||
clean_input(
|
||||
app_config,
|
||||
f"Enter new resource {i+1}"
|
||||
" (press enter to keep current, or '-' to remove):",
|
||||
@@ -127,12 +133,15 @@ async def interactively_revise_ai_settings(
|
||||
)
|
||||
if new_resource == "-":
|
||||
directives.resources.remove(resource)
|
||||
continue
|
||||
elif new_resource:
|
||||
directives.resources[i] = new_resource
|
||||
|
||||
i += 1
|
||||
|
||||
# Add new resources
|
||||
while True:
|
||||
new_resource = await clean_input(
|
||||
new_resource = clean_input(
|
||||
app_config,
|
||||
"Press enter to finish, or enter a resource to add:",
|
||||
)
|
||||
@@ -141,10 +150,12 @@ async def interactively_revise_ai_settings(
|
||||
directives.resources.append(new_resource)
|
||||
|
||||
# Revise best practices
|
||||
for i, best_practice in enumerate(directives.best_practices):
|
||||
i = 0
|
||||
while i < len(directives.best_practices):
|
||||
best_practice = directives.best_practices[i]
|
||||
print_attribute(f"Best Practice {i+1}:", f'"{best_practice}"')
|
||||
new_best_practice = (
|
||||
await clean_input(
|
||||
clean_input(
|
||||
app_config,
|
||||
f"Enter new best practice {i+1}"
|
||||
" (press enter to keep current, or '-' to remove):",
|
||||
@@ -153,12 +164,15 @@ async def interactively_revise_ai_settings(
|
||||
)
|
||||
if new_best_practice == "-":
|
||||
directives.best_practices.remove(best_practice)
|
||||
continue
|
||||
elif new_best_practice:
|
||||
directives.best_practices[i] = new_best_practice
|
||||
|
||||
i += 1
|
||||
|
||||
# Add new best practices
|
||||
while True:
|
||||
new_best_practice = await clean_input(
|
||||
new_best_practice = clean_input(
|
||||
app_config,
|
||||
"Press enter to finish, or add a best practice to add:",
|
||||
)
|
||||
|
||||
64
autogpts/autogpt/autogpt/app/telemetry.py
Normal file
64
autogpts/autogpt/autogpt/app/telemetry.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
from colorama import Fore, Style
|
||||
|
||||
from .utils import (
|
||||
env_file_exists,
|
||||
get_git_user_email,
|
||||
set_env_config_value,
|
||||
vcs_state_diverges_from_master,
|
||||
)
|
||||
|
||||
|
||||
def setup_telemetry() -> None:
|
||||
if os.getenv("TELEMETRY_OPT_IN") is None:
|
||||
# If no .env file is present, don't bother asking to enable telemetry,
|
||||
# to prevent repeated asking in non-persistent environments.
|
||||
if not env_file_exists():
|
||||
return
|
||||
|
||||
allow_telemetry = click.prompt(
|
||||
f"""
|
||||
{Style.BRIGHT}❓ Do you want to enable telemetry? ❓{Style.NORMAL}
|
||||
This means AutoGPT will send diagnostic data to the core development team when something
|
||||
goes wrong, and will help us to diagnose and fix problems earlier and faster. It also
|
||||
allows us to collect basic performance data, which helps us find bottlenecks and other
|
||||
things that slow down the application.
|
||||
|
||||
By entering 'yes', you confirm that you have read and agree to our Privacy Policy,
|
||||
which is available here:
|
||||
https://www.notion.so/auto-gpt/Privacy-Policy-ab11c9c20dbd4de1a15dcffe84d77984
|
||||
|
||||
Please enter 'yes' or 'no'""",
|
||||
type=bool,
|
||||
)
|
||||
set_env_config_value("TELEMETRY_OPT_IN", "true" if allow_telemetry else "false")
|
||||
click.echo(
|
||||
f"❤️ Thank you! Telemetry is {Fore.GREEN}enabled{Fore.RESET}."
|
||||
if allow_telemetry
|
||||
else f"👍 Telemetry is {Fore.RED}disabled{Fore.RESET}."
|
||||
)
|
||||
click.echo(
|
||||
"💡 If you ever change your mind, you can change 'TELEMETRY_OPT_IN' in .env"
|
||||
)
|
||||
click.echo()
|
||||
|
||||
if os.getenv("TELEMETRY_OPT_IN", "").lower() == "true":
|
||||
_setup_sentry()
|
||||
|
||||
|
||||
def _setup_sentry() -> None:
|
||||
import sentry_sdk
|
||||
|
||||
sentry_sdk.init(
|
||||
dsn="https://dc266f2f7a2381194d1c0fa36dff67d8@o4505260022104064.ingest.sentry.io/4506739844710400", # noqa
|
||||
enable_tracing=True,
|
||||
environment=os.getenv(
|
||||
"TELEMETRY_ENVIRONMENT",
|
||||
"production" if not vcs_state_diverges_from_master() else "dev",
|
||||
),
|
||||
)
|
||||
|
||||
# Allow Sentry to distinguish between users
|
||||
sentry_sdk.set_user({"email": get_git_user_email(), "ip_address": "{{auto}}"})
|
||||
@@ -1,58 +1,31 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
import requests
|
||||
from colorama import Fore, Style
|
||||
from git import InvalidGitRepositoryError, Repo
|
||||
from prompt_toolkit import ANSI, PromptSession
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
|
||||
from autogpt.config import Config
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
session = PromptSession(history=InMemoryHistory())
|
||||
|
||||
|
||||
async def clean_input(config: Config, prompt: str = ""):
|
||||
def clean_input(config: "Config", prompt: str = ""):
|
||||
try:
|
||||
if config.chat_messages_enabled:
|
||||
for plugin in config.plugins:
|
||||
if not hasattr(plugin, "can_handle_user_input"):
|
||||
continue
|
||||
if not plugin.can_handle_user_input(user_input=prompt):
|
||||
continue
|
||||
plugin_response = plugin.user_input(user_input=prompt)
|
||||
if not plugin_response:
|
||||
continue
|
||||
if plugin_response.lower() in [
|
||||
"yes",
|
||||
"yeah",
|
||||
"y",
|
||||
"ok",
|
||||
"okay",
|
||||
"sure",
|
||||
"alright",
|
||||
]:
|
||||
return config.authorise_key
|
||||
elif plugin_response.lower() in [
|
||||
"no",
|
||||
"nope",
|
||||
"n",
|
||||
"negative",
|
||||
]:
|
||||
return config.exit_key
|
||||
return plugin_response
|
||||
|
||||
# ask for input, default when just pressing Enter is y
|
||||
logger.debug("Asking user via keyboard...")
|
||||
|
||||
# handle_sigint must be set to False, so the signal handler in the
|
||||
# autogpt/main.py could be employed properly. This referes to
|
||||
# https://github.com/Significant-Gravitas/AutoGPT/pull/4799/files/3966cdfd694c2a80c0333823c3bc3da090f85ed3#r1264278776
|
||||
answer = await session.prompt_async(ANSI(prompt + " "), handle_sigint=False)
|
||||
return answer
|
||||
return click.prompt(
|
||||
text=prompt, prompt_suffix=" ", default="", show_default=False
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("You interrupted AutoGPT")
|
||||
logger.info("Quitting...")
|
||||
@@ -81,6 +54,58 @@ def get_current_git_branch() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def vcs_state_diverges_from_master() -> bool:
|
||||
"""
|
||||
Returns whether a git repo is present and contains changes that are not in `master`.
|
||||
"""
|
||||
paths_we_care_about = "autogpts/autogpt/autogpt/**/*.py"
|
||||
try:
|
||||
repo = Repo(search_parent_directories=True)
|
||||
|
||||
# Check for uncommitted changes in the specified path
|
||||
uncommitted_changes = repo.index.diff(None, paths=paths_we_care_about)
|
||||
if uncommitted_changes:
|
||||
return True
|
||||
|
||||
# Find OG AutoGPT remote
|
||||
for remote in repo.remotes:
|
||||
if remote.url.endswith(
|
||||
tuple(
|
||||
# All permutations of old/new repo name and HTTP(S)/Git URLs
|
||||
f"{prefix}{path}"
|
||||
for prefix in ("://github.com/", "git@github.com:")
|
||||
for path in (
|
||||
f"Significant-Gravitas/{n}.git" for n in ("AutoGPT", "Auto-GPT")
|
||||
)
|
||||
)
|
||||
):
|
||||
og_remote = remote
|
||||
break
|
||||
else:
|
||||
# Original AutoGPT remote is not configured: assume local codebase diverges
|
||||
return True
|
||||
|
||||
master_branch = og_remote.refs.master
|
||||
with contextlib.suppress(StopIteration):
|
||||
next(repo.iter_commits(f"HEAD..{master_branch}", paths=paths_we_care_about))
|
||||
# Local repo is one or more commits ahead of OG AutoGPT master branch
|
||||
return True
|
||||
|
||||
# Relevant part of the codebase is on master
|
||||
return False
|
||||
except InvalidGitRepositoryError:
|
||||
# No git repo present: assume codebase is a clean download
|
||||
return False
|
||||
|
||||
|
||||
def get_git_user_email() -> str:
|
||||
try:
|
||||
repo = Repo(search_parent_directories=True)
|
||||
return repo.config_reader().get_value("user", "email", default="")
|
||||
except InvalidGitRepositoryError:
|
||||
return ""
|
||||
|
||||
|
||||
def get_latest_bulletin() -> tuple[str, bool]:
|
||||
exists = os.path.exists("data/CURRENT_BULLETIN.md")
|
||||
current_bulletin = ""
|
||||
@@ -149,7 +174,7 @@ By using the System, you agree to indemnify, defend, and hold harmless the Proje
|
||||
return legal_text
|
||||
|
||||
|
||||
def print_motd(config: Config, logger: logging.Logger):
|
||||
def print_motd(config: "Config", logger: logging.Logger):
|
||||
motd, is_new_motd = get_latest_bulletin()
|
||||
if motd:
|
||||
motd = markdown_to_ansi_style(motd)
|
||||
@@ -162,7 +187,7 @@ def print_motd(config: Config, logger: logging.Logger):
|
||||
},
|
||||
msg=motd_line,
|
||||
)
|
||||
if is_new_motd and not config.chat_messages_enabled:
|
||||
if is_new_motd:
|
||||
input(
|
||||
Fore.MAGENTA
|
||||
+ Style.BRIGHT
|
||||
@@ -188,3 +213,40 @@ def print_python_version_info(logger: logging.Logger):
|
||||
"parts of AutoGPT with this version. "
|
||||
"Please consider upgrading to Python 3.10 or higher.",
|
||||
)
|
||||
|
||||
|
||||
ENV_FILE_PATH = Path(__file__).parent.parent.parent / ".env"
|
||||
|
||||
|
||||
def env_file_exists() -> bool:
|
||||
return ENV_FILE_PATH.is_file()
|
||||
|
||||
|
||||
def set_env_config_value(key: str, value: str) -> None:
|
||||
"""Sets the specified env variable and updates it in .env as well"""
|
||||
os.environ[key] = value
|
||||
|
||||
with ENV_FILE_PATH.open("r+") as file:
|
||||
lines = file.readlines()
|
||||
file.seek(0)
|
||||
key_already_in_file = False
|
||||
for line in lines:
|
||||
if re.match(rf"^(?:# )?{key}=.*$", line):
|
||||
file.write(f"{key}={value}\n")
|
||||
key_already_in_file = True
|
||||
else:
|
||||
file.write(line)
|
||||
|
||||
if not key_already_in_file:
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
file.truncate()
|
||||
|
||||
|
||||
def is_port_free(port: int, host: str = "127.0.0.1"):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind((host, port)) # Try to bind to the port
|
||||
return True # If successful, the port is free
|
||||
except OSError:
|
||||
return False # If failed, the port is likely in use
|
||||
|
||||
@@ -1,12 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, ParamSpec, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.base import BaseAgent
|
||||
from autogpt.config import Config
|
||||
import re
|
||||
from typing import Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command, CommandOutput, CommandParameter
|
||||
@@ -19,19 +12,35 @@ CO = TypeVar("CO", bound=CommandOutput)
|
||||
|
||||
|
||||
def command(
|
||||
name: str,
|
||||
description: str,
|
||||
parameters: dict[str, JSONSchema],
|
||||
enabled: Literal[True] | Callable[[Config], bool] = True,
|
||||
disabled_reason: Optional[str] = None,
|
||||
aliases: list[str] = [],
|
||||
available: Literal[True] | Callable[[BaseAgent], bool] = True,
|
||||
) -> Callable[[Callable[P, CO]], Callable[P, CO]]:
|
||||
names: list[str] = [],
|
||||
description: Optional[str] = None,
|
||||
parameters: dict[str, JSONSchema] = {},
|
||||
) -> Callable[[Callable[P, CommandOutput]], Command]:
|
||||
"""
|
||||
The command decorator is used to create Command objects from ordinary functions.
|
||||
The command decorator is used to make a Command from a function.
|
||||
|
||||
Args:
|
||||
names (list[str]): The names of the command.
|
||||
If not provided, the function name will be used.
|
||||
description (str): A brief description of what the command does.
|
||||
If not provided, the docstring until double line break will be used
|
||||
(or entire docstring if no double line break is found)
|
||||
parameters (dict[str, JSONSchema]): The parameters of the function
|
||||
that the command executes.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, CO]) -> Callable[P, CO]:
|
||||
def decorator(func: Callable[P, CO]) -> Command:
|
||||
doc = func.__doc__ or ""
|
||||
# If names is not provided, use the function name
|
||||
command_names = names or [func.__name__]
|
||||
# If description is not provided, use the first part of the docstring
|
||||
if not (command_description := description):
|
||||
if not func.__doc__:
|
||||
raise ValueError("Description is required if function has no docstring")
|
||||
# Return the part of the docstring before double line break or everything
|
||||
command_description = re.sub(r"\s+", " ", doc.split("\n\n")[0].strip())
|
||||
|
||||
# Parameters
|
||||
typed_parameters = [
|
||||
CommandParameter(
|
||||
name=param_name,
|
||||
@@ -39,32 +48,15 @@ def command(
|
||||
)
|
||||
for param_name, spec in parameters.items()
|
||||
]
|
||||
cmd = Command(
|
||||
name=name,
|
||||
description=description,
|
||||
|
||||
# Wrap func with Command
|
||||
command = Command(
|
||||
names=command_names,
|
||||
description=command_description,
|
||||
method=func,
|
||||
parameters=typed_parameters,
|
||||
enabled=enabled,
|
||||
disabled_reason=disabled_reason,
|
||||
aliases=aliases,
|
||||
available=available,
|
||||
)
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
setattr(wrapper, "command", cmd)
|
||||
setattr(wrapper, AUTO_GPT_COMMAND_IDENTIFIER, True)
|
||||
|
||||
return wrapper
|
||||
return command
|
||||
|
||||
return decorator
|
||||
|
||||
128
autogpts/autogpt/autogpt/commands/README.md
Normal file
128
autogpts/autogpt/autogpt/commands/README.md
Normal file
@@ -0,0 +1,128 @@
|
||||
# 🧩 Components
|
||||
|
||||
Components are the building blocks of [🤖 Agents](./agents.md). They are classes inheriting `AgentComponent` or implementing one or more [⚙️ Protocols](./protocols.md) that give agent additional abilities or processing.
|
||||
|
||||
Components can be used to implement various functionalities like providing messages to the prompt, executing code, or interacting with external services.
|
||||
They can be enabled or disabled, ordered, and can rely on each other.
|
||||
|
||||
Components assigned in the agent's `__init__` via `self` are automatically detected upon the agent's instantiation.
|
||||
For example inside `__init__`: `self.my_component = MyComponent()`.
|
||||
You can use any valid Python variable name, what matters for the component to be detected is its type (`AgentComponent` or any protocol inheriting from it).
|
||||
|
||||
Visit [Built-in Components](./built-in-components.md) to see what components are available out of the box.
|
||||
|
||||
```py
|
||||
from autogpt.agents import Agent
|
||||
from autogpt.agents.components import AgentComponent
|
||||
|
||||
class HelloComponent(AgentComponent):
|
||||
pass
|
||||
|
||||
class SomeComponent(AgentComponent):
|
||||
def __init__(self, hello_component: HelloComponent):
|
||||
self.hello_component = hello_component
|
||||
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
# These components will be automatically discovered and used
|
||||
self.hello_component = HelloComponent()
|
||||
# We pass HelloComponent to SomeComponent
|
||||
self.some_component = SomeComponent(self.hello_component)
|
||||
```
|
||||
|
||||
## Ordering components
|
||||
|
||||
The execution order of components is important because the latter ones may depend on the results of the former ones.
|
||||
|
||||
### Implicit order
|
||||
|
||||
Components can be ordered implicitly by the agent; each component can set `run_after` list to specify which components should run before it. This is useful when components rely on each other or need to be executed in a specific order. Otherwise, the order of components is alphabetical.
|
||||
|
||||
```py
|
||||
# This component will run after HelloComponent
|
||||
class CalculatorComponent(AgentComponent):
|
||||
run_after = [HelloComponent]
|
||||
```
|
||||
|
||||
### Explicit order
|
||||
|
||||
Sometimes it may be easier to order components explicitly by setting `self.components` list in the agent's `__init__` method. This way you can also ensure there's no circular dependencies and `run_after` is ignored.
|
||||
|
||||
!!! warning
|
||||
Be sure to include all components - by setting `self.components` list, you're overriding the default behavior of discovering components automatically. Since it's usually not intended agent will inform you in the terminal if some components were skipped.
|
||||
|
||||
```py
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
self.hello_component = HelloComponent()
|
||||
self.calculator_component = CalculatorComponent(self.hello_component)
|
||||
# Explicitly set components list
|
||||
self.components = [self.hello_component, self.calculator_component]
|
||||
```
|
||||
|
||||
## Disabling components
|
||||
|
||||
You can control which components are enabled by setting their `_enabled` attribute.
|
||||
Either provide a `bool` value or a `Callable[[], bool]`, will be checked each time
|
||||
the component is about to be executed. This way you can dynamically enable or disable
|
||||
components based on some conditions.
|
||||
You can also provide a reason for disabling the component by setting `_disabled_reason`.
|
||||
The reason will be visible in the debug information.
|
||||
|
||||
```py
|
||||
class DisabledComponent(MessageProvider):
|
||||
def __init__(self):
|
||||
# Disable this component
|
||||
self._enabled = False
|
||||
self._disabled_reason = "This component is disabled because of reasons."
|
||||
|
||||
# Or disable based on some condition, either statically...:
|
||||
self._enabled = self.some_property is not None
|
||||
# ... or dynamically:
|
||||
self._enabled = lambda: self.some_property is not None
|
||||
|
||||
# This method will never be called
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
yield ChatMessage.user("This message won't be seen!")
|
||||
|
||||
def some_condition(self) -> bool:
|
||||
return False
|
||||
```
|
||||
|
||||
If you don't want the component at all, you can just remove it from the agent's `__init__` method. If you want to remove components you inherit from the parent class you can set the relevant attribute to `None`:
|
||||
|
||||
!!! Warning
|
||||
Be careful when removing components that are required by other components. This may lead to errors and unexpected behavior.
|
||||
|
||||
```py
|
||||
class MyAgent(Agent):
|
||||
def __init__(self):
|
||||
super().__init__(...)
|
||||
# Disable WatchdogComponent that is in the parent class
|
||||
self.watchdog = None
|
||||
|
||||
```
|
||||
|
||||
## Exceptions
|
||||
|
||||
Custom errors are provided which can be used to control the execution flow in case something went wrong. All those errors can be raised in protocol methods and will be caught by the agent.
|
||||
By default agent will retry three times and then re-raise an exception if it's still not resolved. All passed arguments are automatically handled and the values are reverted when needed.
|
||||
All errors accept an optional `str` message. There are following errors ordered by increasing broadness:
|
||||
|
||||
1. `ComponentEndpointError`: A single endpoint method failed to execute. Agent will retry the execution of this endpoint on the component.
|
||||
2. `EndpointPipelineError`: A pipeline failed to execute. Agent will retry the execution of the endpoint for all components.
|
||||
3. `ComponentSystemError`: Multiple pipelines failed.
|
||||
|
||||
**Example**
|
||||
|
||||
```py
|
||||
from autogpt.agents.components import ComponentEndpointError
|
||||
from autogpt.agents.protocols import MessageProvider
|
||||
|
||||
# Example of raising an error
|
||||
class MyComponent(MessageProvider):
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
# This will cause the component to always fail
|
||||
# and retry 3 times before re-raising the exception
|
||||
raise ComponentEndpointError("Endpoint error!")
|
||||
```
|
||||
@@ -1,9 +0,0 @@
|
||||
COMMAND_CATEGORIES = [
|
||||
"autogpt.commands.execute_code",
|
||||
"autogpt.commands.file_operations",
|
||||
"autogpt.commands.user_interaction",
|
||||
"autogpt.commands.web_search",
|
||||
"autogpt.commands.web_selenium",
|
||||
"autogpt.commands.system",
|
||||
"autogpt.commands.image_gen",
|
||||
]
|
||||
@@ -1,82 +0,0 @@
|
||||
import functools
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
|
||||
P = ParamSpec("P")
|
||||
T = TypeVar("T")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def sanitize_path_arg(
|
||||
arg_name: str, make_relative: bool = False
|
||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||
"""Sanitizes the specified path (str | Path) argument, resolving it to a Path"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# Get position of path parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
arg_index = list(func.__annotations__.keys()).index(arg_name)
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Sanitized parameter '{arg_name}' absent or not annotated"
|
||||
f" on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
# Get position of agent parameter, in case it is passed as a positional argument
|
||||
try:
|
||||
agent_arg_index = list(func.__annotations__.keys()).index("agent")
|
||||
except ValueError:
|
||||
raise TypeError(
|
||||
f"Parameter 'agent' absent or not annotated"
|
||||
f" on function '{func.__name__}'"
|
||||
)
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.debug(f"Sanitizing arg '{arg_name}' on function '{func.__name__}'")
|
||||
|
||||
# Get Agent from the called function's arguments
|
||||
agent = kwargs.get(
|
||||
"agent", len(args) > agent_arg_index and args[agent_arg_index]
|
||||
)
|
||||
if not isinstance(agent, Agent):
|
||||
raise RuntimeError("Could not get Agent from decorated command's args")
|
||||
|
||||
# Sanitize the specified path argument, if one is given
|
||||
given_path: str | Path | None = kwargs.get(
|
||||
arg_name, len(args) > arg_index and args[arg_index] or None
|
||||
)
|
||||
if given_path:
|
||||
if type(given_path) is str:
|
||||
# Fix workspace path from output in docker environment
|
||||
given_path = re.sub(r"^\/workspace", ".", given_path)
|
||||
|
||||
if given_path in {"", "/", "."}:
|
||||
sanitized_path = agent.workspace.root
|
||||
else:
|
||||
sanitized_path = agent.workspace.get_path(given_path)
|
||||
|
||||
# Make path relative if possible
|
||||
if make_relative and sanitized_path.is_relative_to(
|
||||
agent.workspace.root
|
||||
):
|
||||
sanitized_path = sanitized_path.relative_to(agent.workspace.root)
|
||||
|
||||
if arg_name in kwargs:
|
||||
kwargs[arg_name] = sanitized_path
|
||||
else:
|
||||
# args is an immutable tuple; must be converted to a list to update
|
||||
arg_list = list(args)
|
||||
arg_list[arg_index] = sanitized_path
|
||||
args = tuple(arg_list)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
@@ -1,31 +1,28 @@
|
||||
"""Commands to execute code"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Iterator
|
||||
|
||||
import docker
|
||||
from docker.errors import DockerException, ImageNotFound, NotFound
|
||||
from docker.models.containers import Container as DockerContainer
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import (
|
||||
from autogpt.agents.base import BaseAgentSettings
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import (
|
||||
CodeExecutionError,
|
||||
CommandExecutionError,
|
||||
InvalidArgumentError,
|
||||
OperationNotAllowedError,
|
||||
)
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "execute_code"
|
||||
COMMAND_CATEGORY_TITLE = "Execute Code"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,313 +30,6 @@ ALLOWLIST_CONTROL = "allowlist"
|
||||
DENYLIST_CONTROL = "denylist"
|
||||
|
||||
|
||||
@command(
|
||||
"execute_python_code",
|
||||
"Executes the given Python code inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to run",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
def execute_python_code(code: str, agent: Agent) -> str:
|
||||
"""
|
||||
Create and execute a Python file in a Docker container and return the STDOUT of the
|
||||
executed code.
|
||||
|
||||
If the code generates any data that needs to be captured, use a print statement.
|
||||
|
||||
Args:
|
||||
code (str): The Python code to run.
|
||||
agent (Agent): The Agent executing the command.
|
||||
|
||||
Returns:
|
||||
str: The STDOUT captured from the code when it ran.
|
||||
"""
|
||||
|
||||
tmp_code_file = NamedTemporaryFile(
|
||||
"w", dir=agent.workspace.root, suffix=".py", encoding="utf-8"
|
||||
)
|
||||
tmp_code_file.write(code)
|
||||
tmp_code_file.flush()
|
||||
|
||||
try:
|
||||
return execute_python_file(tmp_code_file.name, agent) # type: ignore
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(*e.args)
|
||||
finally:
|
||||
tmp_code_file.close()
|
||||
|
||||
|
||||
@command(
|
||||
"execute_python_file",
|
||||
"Execute an existing Python file inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to execute",
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
description="The (command line) arguments to pass to the script",
|
||||
required=False,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
),
|
||||
},
|
||||
)
|
||||
@sanitize_path_arg("filename")
|
||||
def execute_python_file(
|
||||
filename: Path, agent: Agent, args: list[str] | str = []
|
||||
) -> str:
|
||||
"""Execute a Python file in a Docker container and return the output
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to execute
|
||||
args (list, optional): The arguments with which to run the python script
|
||||
|
||||
Returns:
|
||||
str: The output of the file
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing python file '{filename}' "
|
||||
f"in working directory '{agent.workspace.root}'"
|
||||
)
|
||||
|
||||
if isinstance(args, str):
|
||||
args = args.split() # Convert space-separated string to a list
|
||||
|
||||
if not str(filename).endswith(".py"):
|
||||
raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")
|
||||
|
||||
file_path = filename
|
||||
if not file_path.is_file():
|
||||
# Mimic the response that you get from the command line to make it
|
||||
# intuitively understandable for the LLM
|
||||
raise FileNotFoundError(
|
||||
f"python: can't open file '{filename}': [Errno 2] No such file or directory"
|
||||
)
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
"AutoGPT is running in a Docker container; "
|
||||
f"executing {file_path} directly..."
|
||||
)
|
||||
result = subprocess.run(
|
||||
["python", "-B", str(file_path)] + args,
|
||||
capture_output=True,
|
||||
encoding="utf8",
|
||||
cwd=str(agent.workspace.root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise CodeExecutionError(result.stderr)
|
||||
|
||||
logger.debug("AutoGPT is not running in a Docker container")
|
||||
try:
|
||||
assert agent.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
# You can replace this with the desired Python image/version
|
||||
# You can find available Python images on Docker Hub:
|
||||
# https://hub.docker.com/_/python
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{agent.state.agent_id}_sandbox"
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
container_name
|
||||
) # type: ignore
|
||||
except NotFound:
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally,"
|
||||
" pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
for line in low_level_client.pull(image_name, stream=True, decode=True):
|
||||
# Print the status and progress, if available
|
||||
status = line.get("status")
|
||||
progress = line.get("progress")
|
||||
if status and progress:
|
||||
logger.info(f"{status}: {progress}")
|
||||
elif status:
|
||||
logger.info(status)
|
||||
|
||||
logger.debug(f"Creating new {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
|
||||
volumes={
|
||||
str(agent.workspace.root): {
|
||||
"bind": "/workspace",
|
||||
"mode": "rw",
|
||||
}
|
||||
},
|
||||
working_dir="/workspace",
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
detach=True,
|
||||
name=container_name,
|
||||
) # type: ignore
|
||||
container_is_fresh = True
|
||||
|
||||
if not container.status == "running":
|
||||
container.start()
|
||||
elif not container_is_fresh:
|
||||
container.restart()
|
||||
|
||||
logger.debug(f"Running {file_path} in container {container.name}...")
|
||||
exec_result = container.exec_run(
|
||||
[
|
||||
"python",
|
||||
"-B",
|
||||
file_path.relative_to(agent.workspace.root).as_posix(),
|
||||
]
|
||||
+ args,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
)
|
||||
|
||||
if exec_result.exit_code != 0:
|
||||
raise CodeExecutionError(exec_result.output.decode("utf-8"))
|
||||
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning(
|
||||
"Could not run the script in a container. "
|
||||
"If you haven't already, please install Docker: "
|
||||
"https://docs.docker.com/get-docker/"
|
||||
)
|
||||
raise CommandExecutionError(f"Could not run the script in a container: {e}")
|
||||
|
||||
|
||||
def validate_command(command: str, config: Config) -> bool:
|
||||
"""Validate a command to ensure it is allowed
|
||||
|
||||
Args:
|
||||
command (str): The command to validate
|
||||
config (Config): The config to use to validate the command
|
||||
|
||||
Returns:
|
||||
bool: True if the command is allowed, False otherwise
|
||||
"""
|
||||
if not command:
|
||||
return False
|
||||
|
||||
command_name = command.split()[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist
|
||||
else:
|
||||
return command_name not in config.shell_denylist
|
||||
|
||||
|
||||
@command(
|
||||
"execute_shell",
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
enabled=lambda config: config.execute_local_commands,
|
||||
disabled_reason="You are not allowed to run local shell commands. To execute"
|
||||
" shell commands, EXECUTE_LOCAL_COMMANDS must be set to 'True' "
|
||||
"in your config file: .env - do not attempt to bypass the restriction.",
|
||||
)
|
||||
def execute_shell(command_line: str, agent: Agent) -> str:
|
||||
"""Execute a shell command and return the output
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
if not validate_command(command_line, agent.legacy_config):
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(agent.workspace.root):
|
||||
os.chdir(agent.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
result = subprocess.run(command_line, capture_output=True, shell=True)
|
||||
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@command(
|
||||
"execute_shell_popen",
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
lambda config: config.execute_local_commands,
|
||||
"You are not allowed to run local shell commands. To execute"
|
||||
" shell commands, EXECUTE_LOCAL_COMMANDS must be set to 'True' "
|
||||
"in your config. Do not attempt to bypass the restriction.",
|
||||
)
|
||||
def execute_shell_popen(command_line: str, agent: Agent) -> str:
|
||||
"""Execute a shell command with Popen and returns an english description
|
||||
of the event and the process id
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
if not validate_command(command_line, agent.legacy_config):
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(agent.workspace.root):
|
||||
os.chdir(agent.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
do_not_show_output = subprocess.DEVNULL
|
||||
process = subprocess.Popen(
|
||||
command_line, shell=True, stdout=do_not_show_output, stderr=do_not_show_output
|
||||
)
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return f"Subprocess started with PID:'{str(process.pid)}'"
|
||||
|
||||
|
||||
def we_are_running_in_a_docker_container() -> bool:
|
||||
"""Check if we are running in a Docker container
|
||||
|
||||
@@ -347,3 +37,360 @@ def we_are_running_in_a_docker_container() -> bool:
|
||||
bool: True if we are running in a Docker container, False otherwise
|
||||
"""
|
||||
return os.path.exists("/.dockerenv")
|
||||
|
||||
|
||||
def is_docker_available() -> bool:
|
||||
"""Check if Docker is available and supports Linux containers
|
||||
|
||||
Returns:
|
||||
bool: True if Docker is available and supports Linux containers, False otherwise
|
||||
"""
|
||||
try:
|
||||
client = docker.from_env()
|
||||
docker_info = client.info()
|
||||
return docker_info["OSType"] == "linux"
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
class CodeExecutorComponent(CommandProvider):
|
||||
"""Provides commands to execute Python code and shell commands."""
|
||||
|
||||
def __init__(
|
||||
self, workspace: FileStorage, state: BaseAgentSettings, config: Config
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.state = state
|
||||
self.legacy_config = config
|
||||
|
||||
if not we_are_running_in_a_docker_container() and not is_docker_available():
|
||||
logger.info(
|
||||
"Docker is not available or does not support Linux containers. "
|
||||
"The code execution commands will not be available."
|
||||
)
|
||||
|
||||
if not self.legacy_config.execute_local_commands:
|
||||
logger.info(
|
||||
"Local shell commands are disabled. To enable them,"
|
||||
" set EXECUTE_LOCAL_COMMANDS to 'True' in your config file."
|
||||
)
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
if we_are_running_in_a_docker_container() or is_docker_available():
|
||||
yield self.execute_python_code
|
||||
yield self.execute_python_file
|
||||
|
||||
if self.legacy_config.execute_local_commands:
|
||||
yield self.execute_shell
|
||||
yield self.execute_shell_popen
|
||||
|
||||
@command(
|
||||
["execute_python_code"],
|
||||
"Executes the given Python code inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"code": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The Python code to run",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
def execute_python_code(self, code: str) -> str:
|
||||
"""
|
||||
Create and execute a Python file in a Docker container
|
||||
and return the STDOUT of the executed code.
|
||||
|
||||
If the code generates any data that needs to be captured,
|
||||
use a print statement.
|
||||
|
||||
Args:
|
||||
code (str): The Python code to run.
|
||||
agent (Agent): The Agent executing the command.
|
||||
|
||||
Returns:
|
||||
str: The STDOUT captured from the code when it ran.
|
||||
"""
|
||||
|
||||
tmp_code_file = NamedTemporaryFile(
|
||||
"w", dir=self.workspace.root, suffix=".py", encoding="utf-8"
|
||||
)
|
||||
tmp_code_file.write(code)
|
||||
tmp_code_file.flush()
|
||||
|
||||
try:
|
||||
return self.execute_python_file(tmp_code_file.name)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(*e.args)
|
||||
finally:
|
||||
tmp_code_file.close()
|
||||
|
||||
@command(
|
||||
["execute_python_file"],
|
||||
"Execute an existing Python file inside a single-use Docker container"
|
||||
" with access to your workspace folder",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to execute",
|
||||
required=True,
|
||||
),
|
||||
"args": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
description="The (command line) arguments to pass to the script",
|
||||
required=False,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
),
|
||||
},
|
||||
)
|
||||
def execute_python_file(self, filename: str, args: list[str] | str = []) -> str:
|
||||
"""Execute a Python file in a Docker container and return the output
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to execute
|
||||
args (list, optional): The arguments with which to run the python script
|
||||
|
||||
Returns:
|
||||
str: The output of the file
|
||||
"""
|
||||
logger.info(
|
||||
f"Executing python file '{filename}' "
|
||||
f"in working directory '{self.workspace.root}'"
|
||||
)
|
||||
|
||||
if isinstance(args, str):
|
||||
args = args.split() # Convert space-separated string to a list
|
||||
|
||||
if not str(filename).endswith(".py"):
|
||||
raise InvalidArgumentError("Invalid file type. Only .py files are allowed.")
|
||||
|
||||
file_path = self.workspace.get_path(filename)
|
||||
if not self.workspace.exists(file_path):
|
||||
# Mimic the response that you get from the command line to make it
|
||||
# intuitively understandable for the LLM
|
||||
raise FileNotFoundError(
|
||||
f"python: can't open file '{filename}': "
|
||||
f"[Errno 2] No such file or directory"
|
||||
)
|
||||
|
||||
if we_are_running_in_a_docker_container():
|
||||
logger.debug(
|
||||
"AutoGPT is running in a Docker container; "
|
||||
f"executing {file_path} directly..."
|
||||
)
|
||||
result = subprocess.run(
|
||||
["python", "-B", str(file_path)] + args,
|
||||
capture_output=True,
|
||||
encoding="utf8",
|
||||
cwd=str(self.workspace.root),
|
||||
)
|
||||
if result.returncode == 0:
|
||||
return result.stdout
|
||||
else:
|
||||
raise CodeExecutionError(result.stderr)
|
||||
|
||||
logger.debug("AutoGPT is not running in a Docker container")
|
||||
try:
|
||||
assert self.state.agent_id, "Need Agent ID to attach Docker container"
|
||||
|
||||
client = docker.from_env()
|
||||
image_name = "python:3-alpine"
|
||||
container_is_fresh = False
|
||||
container_name = f"{self.state.agent_id}_sandbox"
|
||||
try:
|
||||
container: DockerContainer = client.containers.get(
|
||||
container_name
|
||||
) # type: ignore
|
||||
except NotFound:
|
||||
try:
|
||||
client.images.get(image_name)
|
||||
logger.debug(f"Image '{image_name}' found locally")
|
||||
except ImageNotFound:
|
||||
logger.info(
|
||||
f"Image '{image_name}' not found locally,"
|
||||
" pulling from Docker Hub..."
|
||||
)
|
||||
# Use the low-level API to stream the pull response
|
||||
low_level_client = docker.APIClient()
|
||||
for line in low_level_client.pull(
|
||||
image_name, stream=True, decode=True
|
||||
):
|
||||
# Print the status and progress, if available
|
||||
status = line.get("status")
|
||||
progress = line.get("progress")
|
||||
if status and progress:
|
||||
logger.info(f"{status}: {progress}")
|
||||
elif status:
|
||||
logger.info(status)
|
||||
|
||||
logger.debug(f"Creating new {image_name} container...")
|
||||
container: DockerContainer = client.containers.run(
|
||||
image_name,
|
||||
["sleep", "60"], # Max 60 seconds to prevent permanent hangs
|
||||
volumes={
|
||||
str(self.workspace.root): {
|
||||
"bind": "/workspace",
|
||||
"mode": "rw",
|
||||
}
|
||||
},
|
||||
working_dir="/workspace",
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
detach=True,
|
||||
name=container_name,
|
||||
) # type: ignore
|
||||
container_is_fresh = True
|
||||
|
||||
if not container.status == "running":
|
||||
container.start()
|
||||
elif not container_is_fresh:
|
||||
container.restart()
|
||||
|
||||
logger.debug(f"Running {file_path} in container {container.name}...")
|
||||
exec_result = container.exec_run(
|
||||
[
|
||||
"python",
|
||||
"-B",
|
||||
file_path.relative_to(self.workspace.root).as_posix(),
|
||||
]
|
||||
+ args,
|
||||
stderr=True,
|
||||
stdout=True,
|
||||
)
|
||||
|
||||
if exec_result.exit_code != 0:
|
||||
raise CodeExecutionError(exec_result.output.decode("utf-8"))
|
||||
|
||||
return exec_result.output.decode("utf-8")
|
||||
|
||||
except DockerException as e:
|
||||
logger.warning(
|
||||
"Could not run the script in a container. "
|
||||
"If you haven't already, please install Docker: "
|
||||
"https://docs.docker.com/get-docker/"
|
||||
)
|
||||
raise CommandExecutionError(f"Could not run the script in a container: {e}")
|
||||
|
||||
def validate_command(self, command_line: str, config: Config) -> tuple[bool, bool]:
|
||||
"""Check whether a command is allowed and whether it may be executed in a shell.
|
||||
|
||||
If shell command control is enabled, we disallow executing in a shell, because
|
||||
otherwise the model could circumvent the command filter using shell features.
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to validate
|
||||
config (Config): The app config including shell command control settings
|
||||
|
||||
Returns:
|
||||
bool: True if the command is allowed, False otherwise
|
||||
bool: True if the command may be executed in a shell, False otherwise
|
||||
"""
|
||||
if not command_line:
|
||||
return False, False
|
||||
|
||||
command_name = shlex.split(command_line)[0]
|
||||
|
||||
if config.shell_command_control == ALLOWLIST_CONTROL:
|
||||
return command_name in config.shell_allowlist, False
|
||||
elif config.shell_command_control == DENYLIST_CONTROL:
|
||||
return command_name not in config.shell_denylist, False
|
||||
else:
|
||||
return True, True
|
||||
|
||||
@command(
|
||||
["execute_shell"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell(self, command_line: str) -> str:
|
||||
"""Execute a shell command and return the output
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: The output of the command
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
result = subprocess.run(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
capture_output=True,
|
||||
shell=allow_shell,
|
||||
)
|
||||
output = f"STDOUT:\n{result.stdout.decode()}\nSTDERR:\n{result.stderr.decode()}"
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return output
|
||||
|
||||
@command(
|
||||
["execute_shell_popen"],
|
||||
"Execute a Shell Command, non-interactive commands only",
|
||||
{
|
||||
"command_line": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The command line to execute",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def execute_shell_popen(self, command_line: str) -> str:
|
||||
"""Execute a shell command with Popen and returns an english description
|
||||
of the event and the process id
|
||||
|
||||
Args:
|
||||
command_line (str): The command line to execute
|
||||
|
||||
Returns:
|
||||
str: Description of the fact that the process started and its id
|
||||
"""
|
||||
allow_execute, allow_shell = self.validate_command(
|
||||
command_line, self.legacy_config
|
||||
)
|
||||
if not allow_execute:
|
||||
logger.info(f"Command '{command_line}' not allowed")
|
||||
raise OperationNotAllowedError("This shell command is not allowed.")
|
||||
|
||||
current_dir = Path.cwd()
|
||||
# Change dir into workspace if necessary
|
||||
if not current_dir.is_relative_to(self.workspace.root):
|
||||
os.chdir(self.workspace.root)
|
||||
|
||||
logger.info(
|
||||
f"Executing command '{command_line}' in working directory '{os.getcwd()}'"
|
||||
)
|
||||
|
||||
do_not_show_output = subprocess.DEVNULL
|
||||
process = subprocess.Popen(
|
||||
command_line if allow_shell else shlex.split(command_line),
|
||||
shell=allow_shell,
|
||||
stdout=do_not_show_output,
|
||||
stderr=do_not_show_output,
|
||||
)
|
||||
|
||||
# Change back to whatever the prior working dir was
|
||||
os.chdir(current_dir)
|
||||
|
||||
return f"Subprocess started with PID:'{str(process.pid)}'"
|
||||
|
||||
@@ -1,131 +0,0 @@
|
||||
"""Commands to perform operations on files"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from autogpt.agents.features.context import ContextMixin, get_agent_context
|
||||
from autogpt.agents.utils.exceptions import (
|
||||
CommandExecutionError,
|
||||
DuplicateOperationError,
|
||||
)
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.context_item import FileContextItem, FolderContextItem
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "file_operations"
|
||||
COMMAND_CATEGORY_TITLE = "File Operations"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents import Agent, BaseAgent
|
||||
|
||||
|
||||
def agent_implements_context(agent: BaseAgent) -> bool:
|
||||
return isinstance(agent, ContextMixin)
|
||||
|
||||
|
||||
@command(
|
||||
"open_file",
|
||||
"Opens a file for editing or continued viewing;"
|
||||
" creates it if it does not exist yet. "
|
||||
"Note: If you only need to read or write a file once, use `write_to_file` instead.",
|
||||
{
|
||||
"file_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to open",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=agent_implements_context,
|
||||
)
|
||||
@sanitize_path_arg("file_path")
|
||||
def open_file(file_path: Path, agent: Agent) -> tuple[str, FileContextItem]:
|
||||
"""Open a file and return a context item
|
||||
|
||||
Args:
|
||||
file_path (Path): The path of the file to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
FileContextItem: A ContextItem representing the opened file
|
||||
"""
|
||||
# Try to make the file path relative
|
||||
relative_file_path = None
|
||||
with contextlib.suppress(ValueError):
|
||||
relative_file_path = file_path.relative_to(agent.workspace.root)
|
||||
|
||||
assert (agent_context := get_agent_context(agent)) is not None
|
||||
|
||||
created = False
|
||||
if not file_path.exists():
|
||||
file_path.touch()
|
||||
created = True
|
||||
elif not file_path.is_file():
|
||||
raise CommandExecutionError(f"{file_path} exists but is not a file")
|
||||
|
||||
file_path = relative_file_path or file_path
|
||||
|
||||
file = FileContextItem(
|
||||
file_path_in_workspace=file_path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if file in agent_context:
|
||||
raise DuplicateOperationError(f"The file {file_path} is already open")
|
||||
|
||||
return (
|
||||
f"File {file_path}{' created,' if created else ''} has been opened"
|
||||
" and added to the context ✅",
|
||||
file,
|
||||
)
|
||||
|
||||
|
||||
@command(
|
||||
"open_folder",
|
||||
"Open a folder to keep track of its content",
|
||||
{
|
||||
"path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the folder to open",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=agent_implements_context,
|
||||
)
|
||||
@sanitize_path_arg("path")
|
||||
def open_folder(path: Path, agent: Agent) -> tuple[str, FolderContextItem]:
|
||||
"""Open a folder and return a context item
|
||||
|
||||
Args:
|
||||
path (Path): The path of the folder to open
|
||||
|
||||
Returns:
|
||||
str: A status message indicating what happened
|
||||
FolderContextItem: A ContextItem representing the opened folder
|
||||
"""
|
||||
# Try to make the path relative
|
||||
relative_path = None
|
||||
with contextlib.suppress(ValueError):
|
||||
relative_path = path.relative_to(agent.workspace.root)
|
||||
|
||||
assert (agent_context := get_agent_context(agent)) is not None
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"open_folder {path} failed: no such file or directory")
|
||||
elif not path.is_dir():
|
||||
raise CommandExecutionError(f"{path} exists but is not a folder")
|
||||
|
||||
path = relative_path or path
|
||||
|
||||
folder = FolderContextItem(
|
||||
path_in_workspace=path,
|
||||
workspace_path=agent.workspace.root,
|
||||
)
|
||||
if folder in agent_context:
|
||||
raise DuplicateOperationError(f"The folder {path} is already open")
|
||||
|
||||
return f"Folder {path} has been opened and added to the context ✅", folder
|
||||
@@ -1,268 +0,0 @@
|
||||
"""Commands to perform operations on files"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import os.path
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Literal
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import DuplicateOperationError
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.memory.vector import MemoryItemFactory, VectorMemory
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
from .file_operations_utils import decode_textual_file
|
||||
|
||||
COMMAND_CATEGORY = "file_operations"
|
||||
COMMAND_CATEGORY_TITLE = "File Operations"
|
||||
|
||||
|
||||
from .file_context import open_file, open_folder # NOQA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Operation = Literal["write", "append", "delete"]
|
||||
|
||||
|
||||
def text_checksum(text: str) -> str:
|
||||
"""Get the hex checksum for the given text."""
|
||||
return hashlib.md5(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def operations_from_log(
|
||||
log_path: str | Path,
|
||||
) -> Iterator[
|
||||
tuple[Literal["write", "append"], str, str] | tuple[Literal["delete"], str, None]
|
||||
]:
|
||||
"""Parse the file operations log and return a tuple containing the log entries"""
|
||||
try:
|
||||
log = open(log_path, "r", encoding="utf-8")
|
||||
except FileNotFoundError:
|
||||
return
|
||||
|
||||
for line in log:
|
||||
line = line.replace("File Operation Logger", "").strip()
|
||||
if not line:
|
||||
continue
|
||||
operation, tail = line.split(": ", maxsplit=1)
|
||||
operation = operation.strip()
|
||||
if operation in ("write", "append"):
|
||||
path, checksum = (x.strip() for x in tail.rsplit(" #", maxsplit=1))
|
||||
yield (operation, path, checksum)
|
||||
elif operation == "delete":
|
||||
yield (operation, tail.strip(), None)
|
||||
|
||||
log.close()
|
||||
|
||||
|
||||
def file_operations_state(log_path: str | Path) -> dict[str, str]:
|
||||
"""Iterates over the operations log and returns the expected state.
|
||||
|
||||
Parses a log file at file_manager.file_ops_log_path to construct a dictionary
|
||||
that maps each file path written or appended to its checksum. Deleted files are
|
||||
removed from the dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping file paths to their checksums.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file_manager.file_ops_log_path is not found.
|
||||
ValueError: If the log file content is not in the expected format.
|
||||
"""
|
||||
state = {}
|
||||
for operation, path, checksum in operations_from_log(log_path):
|
||||
if operation in ("write", "append"):
|
||||
state[path] = checksum
|
||||
elif operation == "delete":
|
||||
del state[path]
|
||||
return state
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
def is_duplicate_operation(
|
||||
operation: Operation, file_path: Path, agent: Agent, checksum: str | None = None
|
||||
) -> bool:
|
||||
"""Check if the operation has already been performed
|
||||
|
||||
Args:
|
||||
operation: The operation to check for
|
||||
file_path: The name of the file to check for
|
||||
agent: The agent
|
||||
checksum: The checksum of the contents to be written
|
||||
|
||||
Returns:
|
||||
True if the operation has already been performed on the file
|
||||
"""
|
||||
state = file_operations_state(agent.file_manager.file_ops_log_path)
|
||||
if operation == "delete" and str(file_path) not in state:
|
||||
return True
|
||||
if operation == "write" and state.get(str(file_path)) == checksum:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@sanitize_path_arg("file_path", make_relative=True)
|
||||
def log_operation(
|
||||
operation: Operation,
|
||||
file_path: str | Path,
|
||||
agent: Agent,
|
||||
checksum: str | None = None,
|
||||
) -> None:
|
||||
"""Log the file operation to the file_logger.log
|
||||
|
||||
Args:
|
||||
operation: The operation to log
|
||||
file_path: The name of the file the operation was performed on
|
||||
checksum: The checksum of the contents to be written
|
||||
"""
|
||||
log_entry = f"{operation}: {file_path}"
|
||||
if checksum is not None:
|
||||
log_entry += f" #{checksum}"
|
||||
logger.debug(f"Logging file operation: {log_entry}")
|
||||
append_to_file(
|
||||
agent.file_manager.file_ops_log_path, f"{log_entry}\n", agent, should_log=False
|
||||
)
|
||||
|
||||
|
||||
@command(
|
||||
"read_file",
|
||||
"Read an existing file",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path of the file to read",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def read_file(filename: str | Path, agent: Agent) -> str:
|
||||
"""Read a file and return the contents
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to read
|
||||
|
||||
Returns:
|
||||
str: The contents of the file
|
||||
"""
|
||||
file = agent.workspace.open_file(filename, binary=True)
|
||||
content = decode_textual_file(file, os.path.splitext(filename)[1], logger)
|
||||
|
||||
# # TODO: invalidate/update memory when file is edited
|
||||
# file_memory = MemoryItem.from_text_file(content, str(filename), agent.config)
|
||||
# if len(file_memory.chunks) > 1:
|
||||
# return file_memory.summary
|
||||
|
||||
return content
|
||||
|
||||
|
||||
def ingest_file(
|
||||
filename: str,
|
||||
memory: VectorMemory,
|
||||
) -> None:
|
||||
"""
|
||||
Ingest a file by reading its content, splitting it into chunks with a specified
|
||||
maximum length and overlap, and adding the chunks to the memory storage.
|
||||
|
||||
Args:
|
||||
filename: The name of the file to ingest
|
||||
memory: An object with an add() method to store the chunks in memory
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Ingesting file {filename}")
|
||||
content = read_file(filename)
|
||||
|
||||
# TODO: differentiate between different types of files
|
||||
file_memory = MemoryItemFactory.from_text_file(content, filename)
|
||||
logger.debug(f"Created memory: {file_memory.dump(True)}")
|
||||
memory.add(file_memory)
|
||||
|
||||
logger.info(f"Ingested {len(file_memory.e_chunks)} chunks from {filename}")
|
||||
except Exception as err:
|
||||
logger.warning(f"Error while ingesting file '{filename}': {err}")
|
||||
|
||||
|
||||
@command(
|
||||
"write_file",
|
||||
"Write a file, creating it if necessary. If the file exists, it is overwritten.",
|
||||
{
|
||||
"filename": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The name of the file to write to",
|
||||
required=True,
|
||||
),
|
||||
"contents": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The contents to write to the file",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
aliases=["create_file"],
|
||||
)
|
||||
async def write_to_file(filename: str | Path, contents: str, agent: Agent) -> str:
|
||||
"""Write contents to a file
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to write to
|
||||
contents (str): The contents to write to the file
|
||||
|
||||
Returns:
|
||||
str: A message indicating success or failure
|
||||
"""
|
||||
checksum = text_checksum(contents)
|
||||
if is_duplicate_operation("write", Path(filename), agent, checksum):
|
||||
raise DuplicateOperationError(f"File {filename} has already been updated.")
|
||||
|
||||
if directory := os.path.dirname(filename):
|
||||
agent.workspace.get_path(directory).mkdir(exist_ok=True)
|
||||
await agent.workspace.write_file(filename, contents)
|
||||
log_operation("write", filename, agent, checksum)
|
||||
return f"File {filename} has been written successfully."
|
||||
|
||||
|
||||
def append_to_file(
|
||||
filename: Path, text: str, agent: Agent, should_log: bool = True
|
||||
) -> None:
|
||||
"""Append text to a file
|
||||
|
||||
Args:
|
||||
filename (Path): The name of the file to append to
|
||||
text (str): The text to append to the file
|
||||
should_log (bool): Should log output
|
||||
"""
|
||||
directory = os.path.dirname(filename)
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
with open(filename, "a") as f:
|
||||
f.write(text)
|
||||
|
||||
if should_log:
|
||||
with open(filename, "r") as f:
|
||||
checksum = text_checksum(f.read())
|
||||
log_operation("append", filename, agent, checksum=checksum)
|
||||
|
||||
|
||||
@command(
|
||||
"list_folder",
|
||||
"List the items in a folder",
|
||||
{
|
||||
"folder": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The folder to list files in",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def list_folder(folder: str | Path, agent: Agent) -> list[str]:
|
||||
"""Lists files in a folder recursively
|
||||
|
||||
Args:
|
||||
folder (Path): The folder to search in
|
||||
|
||||
Returns:
|
||||
list[str]: A list of files found in the folder
|
||||
"""
|
||||
return [str(p) for p in agent.workspace.list(folder)]
|
||||
@@ -1,58 +1,61 @@
|
||||
"""Commands to perform Git operations"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from git.repo import Repo
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import CommandExecutionError
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.url_utils.validators import validate_url
|
||||
|
||||
from .decorators import sanitize_path_arg
|
||||
|
||||
COMMAND_CATEGORY = "git_operations"
|
||||
COMMAND_CATEGORY_TITLE = "Git Operations"
|
||||
from autogpt.utils.exceptions import CommandExecutionError
|
||||
|
||||
|
||||
@command(
|
||||
"clone_repository",
|
||||
"Clones a Repository",
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL of the repository to clone",
|
||||
required=True,
|
||||
),
|
||||
"clone_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path to clone the repository to",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
lambda config: bool(config.github_username and config.github_api_key),
|
||||
"Configure github_username and github_api_key.",
|
||||
)
|
||||
@sanitize_path_arg("clone_path")
|
||||
@validate_url
|
||||
def clone_repository(url: str, clone_path: Path, agent: Agent) -> str:
|
||||
"""Clone a GitHub repository locally.
|
||||
class GitOperationsComponent(CommandProvider):
|
||||
"""Provides commands to perform Git operations."""
|
||||
|
||||
Args:
|
||||
url (str): The URL of the repository to clone.
|
||||
clone_path (Path): The path to clone the repository to.
|
||||
def __init__(self, config: Config):
|
||||
self._enabled = bool(config.github_username and config.github_api_key)
|
||||
self._disabled_reason = "Configure github_username and github_api_key."
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The result of the clone operation.
|
||||
"""
|
||||
split_url = url.split("//")
|
||||
auth_repo_url = f"//{agent.legacy_config.github_username}:{agent.legacy_config.github_api_key}@".join( # noqa: E501
|
||||
split_url
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.clone_repository
|
||||
|
||||
@command(
|
||||
parameters={
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL of the repository to clone",
|
||||
required=True,
|
||||
),
|
||||
"clone_path": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The path to clone the repository to",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
try:
|
||||
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(f"Could not clone repo: {e}")
|
||||
@validate_url
|
||||
def clone_repository(self, url: str, clone_path: Path) -> str:
|
||||
"""Clone a GitHub repository locally.
|
||||
|
||||
return f"""Cloned {url} to {clone_path}"""
|
||||
Args:
|
||||
url (str): The URL of the repository to clone.
|
||||
clone_path (Path): The path to clone the repository to.
|
||||
|
||||
Returns:
|
||||
str: The result of the clone operation.
|
||||
"""
|
||||
split_url = url.split("//")
|
||||
auth_repo_url = (
|
||||
f"//{self.legacy_config.github_username}:"
|
||||
f"{self.legacy_config.github_api_key}@".join(split_url)
|
||||
)
|
||||
try:
|
||||
Repo.clone_from(url=auth_repo_url, to_path=clone_path)
|
||||
except Exception as e:
|
||||
raise CommandExecutionError(f"Could not clone repo: {e}")
|
||||
|
||||
return f"""Cloned {url} to {clone_path}"""
|
||||
|
||||
@@ -7,206 +7,216 @@ import time
|
||||
import uuid
|
||||
from base64 import b64decode
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
from PIL import Image
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "text_to_image"
|
||||
COMMAND_CATEGORY_TITLE = "Text to Image"
|
||||
|
||||
from autogpt.file_storage.base import FileStorage
|
||||
from autogpt.models.command import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"generate_image",
|
||||
"Generates an Image",
|
||||
{
|
||||
"prompt": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The prompt used to generate the image",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
lambda config: bool(config.image_provider),
|
||||
"Requires a image provider to be set.",
|
||||
)
|
||||
def generate_image(prompt: str, agent: Agent, size: int = 256) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
class ImageGeneratorComponent(CommandProvider):
|
||||
"""A component that provides commands to generate images from text prompts."""
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
Not supported by HuggingFace.
|
||||
def __init__(self, workspace: FileStorage, config: Config):
|
||||
self._enabled = bool(config.image_provider)
|
||||
self._disabled_reason = "No image provider set."
|
||||
self.workspace = workspace
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = agent.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.generate_image
|
||||
|
||||
# DALL-E
|
||||
if agent.legacy_config.image_provider == "dalle":
|
||||
return generate_image_with_dalle(prompt, filename, size, agent)
|
||||
# HuggingFace
|
||||
elif agent.legacy_config.image_provider == "huggingface":
|
||||
return generate_image_with_hf(prompt, filename, agent)
|
||||
# SD WebUI
|
||||
elif agent.legacy_config.image_provider == "sdwebui":
|
||||
return generate_image_with_sd_webui(prompt, filename, agent, size)
|
||||
return "No Image Provider Set"
|
||||
@command(
|
||||
parameters={
|
||||
"prompt": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The prompt used to generate the image",
|
||||
required=True,
|
||||
),
|
||||
"size": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The size of the image",
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def generate_image(self, prompt: str, size: int) -> str:
|
||||
"""Generate an image from a prompt.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
Not supported by HuggingFace.
|
||||
|
||||
def generate_image_with_hf(prompt: str, output_file: Path, agent: Agent) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
filename = self.workspace.root / f"{str(uuid.uuid4())}.jpg"
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
# DALL-E
|
||||
if self.legacy_config.image_provider == "dalle":
|
||||
return self.generate_image_with_dalle(prompt, filename, size)
|
||||
# HuggingFace
|
||||
elif self.legacy_config.image_provider == "huggingface":
|
||||
return self.generate_image_with_hf(prompt, filename)
|
||||
# SD WebUI
|
||||
elif self.legacy_config.image_provider == "sdwebui":
|
||||
return self.generate_image_with_sd_webui(prompt, filename, size)
|
||||
return "No Image Provider Set"
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{agent.legacy_config.huggingface_image_model}" # noqa: E501
|
||||
if agent.legacy_config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
def generate_image_with_hf(self, prompt: str, output_file: Path) -> str:
|
||||
"""Generate an image with HuggingFace's API.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
API_URL = f"https://api-inference.huggingface.co/models/{self.legacy_config.huggingface_image_model}" # noqa: E501
|
||||
if self.legacy_config.huggingface_api_token is None:
|
||||
raise ValueError(
|
||||
"You need to set your Hugging Face API token in the config file."
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.legacy_config.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 10:
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
json={
|
||||
"inputs": prompt,
|
||||
},
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
try:
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
image.save(output_file)
|
||||
return f"Saved to disk: {output_file}"
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
else:
|
||||
try:
|
||||
error = json.loads(response.text)
|
||||
if "estimated_time" in error:
|
||||
delay = error["estimated_time"]
|
||||
logger.debug(response.text)
|
||||
logger.info("Retrying in", delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
|
||||
retry_count += 1
|
||||
|
||||
return "Error creating image."
|
||||
|
||||
def generate_image_with_dalle(
|
||||
self, prompt: str, output_file: Path, size: int
|
||||
) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
size (int): The size of the image
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
logger.info(
|
||||
"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
|
||||
f"Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = OpenAI(
|
||||
api_key=self.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {agent.legacy_config.huggingface_api_token}",
|
||||
"X-Use-Cache": "false",
|
||||
}
|
||||
|
||||
retry_count = 0
|
||||
while retry_count < 10:
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image_data = b64decode(response.data[0].b64_json)
|
||||
|
||||
with open(output_file, mode="wb") as png:
|
||||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
self,
|
||||
prompt: str,
|
||||
output_file: Path,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if self.legacy_config.sd_webui_auth:
|
||||
username, password = self.legacy_config.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
API_URL,
|
||||
headers=headers,
|
||||
f"{self.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"inputs": prompt,
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"config_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
if response.ok:
|
||||
try:
|
||||
image = Image.open(io.BytesIO(response.content))
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
image.save(output_file)
|
||||
return f"Saved to disk: {output_file}"
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
else:
|
||||
try:
|
||||
error = json.loads(response.text)
|
||||
if "estimated_time" in error:
|
||||
delay = error["estimated_time"]
|
||||
logger.debug(response.text)
|
||||
logger.info("Retrying in", delay)
|
||||
time.sleep(delay)
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
break
|
||||
logger.info(f"Image Generated for prompt: '{prompt}'")
|
||||
|
||||
retry_count += 1
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(output_file)
|
||||
|
||||
return "Error creating image."
|
||||
|
||||
|
||||
def generate_image_with_dalle(
|
||||
prompt: str, output_file: Path, size: int, agent: Agent
|
||||
) -> str:
|
||||
"""Generate an image with DALL-E.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (Path): The filename to save the image to
|
||||
size (int): The size of the image
|
||||
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
|
||||
# Check for supported image sizes
|
||||
if size not in [256, 512, 1024]:
|
||||
closest = min([256, 512, 1024], key=lambda x: abs(x - size))
|
||||
logger.info(
|
||||
"DALL-E only supports image sizes of 256x256, 512x512, or 1024x1024. "
|
||||
f"Setting to {closest}, was {size}."
|
||||
)
|
||||
size = closest
|
||||
|
||||
response = OpenAI(
|
||||
api_key=agent.legacy_config.openai_credentials.api_key.get_secret_value()
|
||||
).images.generate(
|
||||
prompt=prompt,
|
||||
n=1,
|
||||
size=f"{size}x{size}",
|
||||
response_format="b64_json",
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt:{prompt}")
|
||||
|
||||
image_data = b64decode(response.data[0].b64_json)
|
||||
|
||||
with open(output_file, mode="wb") as png:
|
||||
png.write(image_data)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
|
||||
def generate_image_with_sd_webui(
|
||||
prompt: str,
|
||||
output_file: Path,
|
||||
agent: Agent,
|
||||
size: int = 512,
|
||||
negative_prompt: str = "",
|
||||
extra: dict = {},
|
||||
) -> str:
|
||||
"""Generate an image with Stable Diffusion webui.
|
||||
Args:
|
||||
prompt (str): The prompt to use
|
||||
filename (str): The filename to save the image to
|
||||
size (int, optional): The size of the image. Defaults to 256.
|
||||
negative_prompt (str, optional): The negative prompt to use. Defaults to "".
|
||||
extra (dict, optional): Extra parameters to pass to the API. Defaults to {}.
|
||||
Returns:
|
||||
str: The filename of the image
|
||||
"""
|
||||
# Create a session and set the basic auth if needed
|
||||
s = requests.Session()
|
||||
if agent.legacy_config.sd_webui_auth:
|
||||
username, password = agent.legacy_config.sd_webui_auth.split(":")
|
||||
s.auth = (username, password or "")
|
||||
|
||||
# Generate the images
|
||||
response = requests.post(
|
||||
f"{agent.legacy_config.sd_webui_url}/sdapi/v1/txt2img",
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_index": "DDIM",
|
||||
"steps": 20,
|
||||
"config_scale": 7.0,
|
||||
"width": size,
|
||||
"height": size,
|
||||
"n_iter": 1,
|
||||
**extra,
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f"Image Generated for prompt: '{prompt}'")
|
||||
|
||||
# Save the image to disk
|
||||
response = response.json()
|
||||
b64 = b64decode(response["images"][0].split(",", 1)[0])
|
||||
image = Image.open(io.BytesIO(b64))
|
||||
image.save(output_file)
|
||||
|
||||
return f"Saved to disk: {output_file}"
|
||||
return f"Saved to disk: {output_file}"
|
||||
|
||||
@@ -1,69 +1,55 @@
|
||||
"""Commands to control the internal state of the program"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
from autogpt.agents.features.context import get_agent_context
|
||||
from autogpt.agents.utils.exceptions import AgentTerminated, InvalidArgumentError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider, MessageProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.ai_profile import AIProfile
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "system"
|
||||
COMMAND_CATEGORY_TITLE = "System"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import AgentFinished
|
||||
from autogpt.utils.utils import DEFAULT_FINISH_COMMAND
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"finish",
|
||||
"Use this to shut down once you have completed your task,"
|
||||
" or when there are insurmountable problems that make it impossible"
|
||||
" for you to finish your task.",
|
||||
{
|
||||
"reason": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="A summary to the user of how the goals were accomplished",
|
||||
required=True,
|
||||
class SystemComponent(DirectiveProvider, MessageProvider, CommandProvider):
|
||||
"""Component for system messages and commands."""
|
||||
|
||||
def __init__(self, config: Config, profile: AIProfile):
|
||||
self.legacy_config = config
|
||||
self.profile = profile
|
||||
|
||||
def get_constraints(self) -> Iterator[str]:
|
||||
if self.profile.api_budget > 0.0:
|
||||
yield (
|
||||
f"It takes money to let you run. "
|
||||
f"Your API budget is ${self.profile.api_budget:.3f}"
|
||||
)
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
# Clock
|
||||
yield ChatMessage.system(
|
||||
f"## Clock\nThe current time and date is {time.strftime('%c')}"
|
||||
)
|
||||
},
|
||||
)
|
||||
def finish(reason: str, agent: Agent) -> None:
|
||||
"""
|
||||
A function that takes in a string and exits the program
|
||||
|
||||
Parameters:
|
||||
reason (str): A summary to the user of how the goals were accomplished.
|
||||
Returns:
|
||||
A result string from create chat completion. A list of suggestions to
|
||||
improve the code.
|
||||
"""
|
||||
raise AgentTerminated(reason)
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.finish
|
||||
|
||||
|
||||
@command(
|
||||
"hide_context_item",
|
||||
"Hide an open file, folder or other context item, to save memory.",
|
||||
{
|
||||
"number": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The 1-based index of the context item to hide",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
available=lambda a: bool(get_agent_context(a)),
|
||||
)
|
||||
def close_context_item(number: int, agent: Agent) -> str:
|
||||
assert (context := get_agent_context(agent)) is not None
|
||||
|
||||
if number > len(context.items) or number == 0:
|
||||
raise InvalidArgumentError(f"Index {number} out of range")
|
||||
|
||||
context.close(number)
|
||||
return f"Context item {number} hidden ✅"
|
||||
@command(
|
||||
names=[DEFAULT_FINISH_COMMAND],
|
||||
parameters={
|
||||
"reason": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="A summary to the user of how the goals were accomplished",
|
||||
required=True,
|
||||
),
|
||||
},
|
||||
)
|
||||
def finish(self, reason: str):
|
||||
"""Use this to shut down once you have completed your task,
|
||||
or when there are insurmountable problems that make it impossible
|
||||
for you to finish your task."""
|
||||
raise AgentFinished(reason)
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def get_datetime() -> str:
|
||||
"""Return the current date and time
|
||||
|
||||
Returns:
|
||||
str: The current date and time
|
||||
"""
|
||||
return "Current date and time: " + datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,32 +1,37 @@
|
||||
"""Commands to interact with the user"""
|
||||
from typing import Iterator
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.protocols import CommandProvider
|
||||
from autogpt.app.utils import clean_input
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "user_interaction"
|
||||
COMMAND_CATEGORY_TITLE = "User Interaction"
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.utils import DEFAULT_ASK_COMMAND
|
||||
|
||||
|
||||
@command(
|
||||
"ask_user",
|
||||
(
|
||||
"If you need more details or information regarding the given goals,"
|
||||
" you can ask the user for input"
|
||||
),
|
||||
{
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The question or prompt to the user",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
enabled=lambda config: not config.noninteractive_mode,
|
||||
)
|
||||
async def ask_user(question: str, agent: Agent) -> str:
|
||||
print(f"\nQ: {question}")
|
||||
resp = await clean_input(agent.legacy_config, "A:")
|
||||
return f"The user's answer: '{resp}'"
|
||||
class UserInteractionComponent(CommandProvider):
|
||||
"""Provides commands to interact with the user."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
self.config = config
|
||||
self._enabled = not config.noninteractive_mode
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.ask_user
|
||||
|
||||
@command(
|
||||
names=[DEFAULT_ASK_COMMAND],
|
||||
parameters={
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The question or prompt to the user",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
)
|
||||
def ask_user(self, question: str) -> str:
|
||||
"""If you need more details or information regarding the given goals,
|
||||
you can ask the user for input."""
|
||||
print(f"\nQ: {question}")
|
||||
resp = clean_input(self.config, "A:")
|
||||
return f"The user's answer: '{resp}'"
|
||||
|
||||
@@ -1,171 +1,195 @@
|
||||
"""Commands to search the web with"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from itertools import islice
|
||||
from typing import Iterator
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.agents.utils.exceptions import ConfigurationError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
|
||||
COMMAND_CATEGORY = "web_search"
|
||||
COMMAND_CATEGORY_TITLE = "Web Search"
|
||||
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.utils.exceptions import ConfigurationError
|
||||
|
||||
DUCKDUCKGO_MAX_ATTEMPTS = 3
|
||||
|
||||
|
||||
@command(
|
||||
"web_search",
|
||||
"Searches the web",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
aliases=["search"],
|
||||
)
|
||||
def web_search(query: str, agent: Agent, num_results: int = 8) -> str:
|
||||
"""Return the results of a Google search
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
search_results = []
|
||||
attempts = 0
|
||||
|
||||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
|
||||
if not query:
|
||||
return json.dumps(search_results)
|
||||
|
||||
results = DDGS().text(query)
|
||||
search_results = list(islice(results, num_results))
|
||||
|
||||
if search_results:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
search_results = [
|
||||
{
|
||||
"title": r["title"],
|
||||
"url": r["href"],
|
||||
**({"exerpt": r["body"]} if r.get("body") else {}),
|
||||
}
|
||||
for r in search_results
|
||||
]
|
||||
|
||||
results = (
|
||||
"## Search results\n"
|
||||
# "Read these results carefully."
|
||||
# " Extract the information you need for your task from the list of results"
|
||||
# " if possible. Otherwise, choose a webpage from the list to read entirely."
|
||||
# "\n\n"
|
||||
) + "\n\n".join(
|
||||
f"### \"{r['title']}\"\n"
|
||||
f"**URL:** {r['url']} \n"
|
||||
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
|
||||
for r in search_results
|
||||
)
|
||||
return safe_google_results(results)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@command(
|
||||
"google",
|
||||
"Google Search",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
)
|
||||
},
|
||||
lambda config: bool(config.google_api_key)
|
||||
and bool(config.google_custom_search_engine_id),
|
||||
"Configure google_api_key and custom_search_engine_id.",
|
||||
aliases=["search"],
|
||||
)
|
||||
def google(query: str, agent: Agent, num_results: int = 8) -> str | list[str]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
class WebSearchComponent(DirectiveProvider, CommandProvider):
|
||||
"""Provides commands to search the web."""
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
def __init__(self, config: Config):
|
||||
self.legacy_config = config
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
# Get the Google API key and Custom Search Engine ID from the config file
|
||||
api_key = agent.legacy_config.google_api_key
|
||||
custom_search_engine_id = agent.legacy_config.google_custom_search_engine_id
|
||||
|
||||
# Initialize the Custom Search API service
|
||||
service = build("customsearch", "v1", developerKey=api_key)
|
||||
|
||||
# Send the search query and retrieve the results
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_links = [item["link"] for item in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
if (
|
||||
not self.legacy_config.google_api_key
|
||||
or not self.legacy_config.google_custom_search_engine_id
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"The provided Google API key is invalid or missing."
|
||||
logger.info(
|
||||
"Configure google_api_key and custom_search_engine_id "
|
||||
"to use Google API search."
|
||||
)
|
||||
raise
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return safe_google_results(search_results_links)
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "Internet access for searches and information gathering."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.web_search
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a Google search in a safe format.
|
||||
if (
|
||||
self.legacy_config.google_api_key
|
||||
and self.legacy_config.google_custom_search_engine_id
|
||||
):
|
||||
yield self.google
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
@command(
|
||||
["web_search", "search"],
|
||||
"Searches the web",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def web_search(self, query: str, num_results: int = 8) -> str:
|
||||
"""Return the results of a Google search
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
search_results = []
|
||||
attempts = 0
|
||||
|
||||
while attempts < DUCKDUCKGO_MAX_ATTEMPTS:
|
||||
if not query:
|
||||
return json.dumps(search_results)
|
||||
|
||||
search_results = DDGS().text(query, max_results=num_results)
|
||||
|
||||
if search_results:
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
attempts += 1
|
||||
|
||||
search_results = [
|
||||
{
|
||||
"title": r["title"],
|
||||
"url": r["href"],
|
||||
**({"exerpt": r["body"]} if r.get("body") else {}),
|
||||
}
|
||||
for r in search_results
|
||||
]
|
||||
|
||||
results = ("## Search results\n") + "\n\n".join(
|
||||
f"### \"{r['title']}\"\n"
|
||||
f"**URL:** {r['url']} \n"
|
||||
"**Excerpt:** " + (f'"{exerpt}"' if (exerpt := r.get("exerpt")) else "N/A")
|
||||
for r in search_results
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
return self.safe_google_results(results)
|
||||
|
||||
@command(
|
||||
["google"],
|
||||
"Google Search",
|
||||
{
|
||||
"query": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The search query",
|
||||
required=True,
|
||||
),
|
||||
"num_results": JSONSchema(
|
||||
type=JSONSchema.Type.INTEGER,
|
||||
description="The number of results to return",
|
||||
minimum=1,
|
||||
maximum=10,
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
def google(self, query: str, num_results: int = 8) -> str | list[str]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
# Get the Google API key and Custom Search Engine ID from the config file
|
||||
api_key = self.legacy_config.google_api_key
|
||||
custom_search_engine_id = self.legacy_config.google_custom_search_engine_id
|
||||
|
||||
# Initialize the Custom Search API service
|
||||
service = build("customsearch", "v1", developerKey=api_key)
|
||||
|
||||
# Send the search query and retrieve the results
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_links = [item["link"] for item in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
raise ConfigurationError(
|
||||
"The provided Google API key is invalid or missing."
|
||||
)
|
||||
raise
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return self.safe_google_results(search_results_links)
|
||||
|
||||
def safe_google_results(self, results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a Google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result.encode("utf-8", "ignore").decode("utf-8") for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""Commands for browsing a website"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from sys import platform
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
from typing import Iterator, Type
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
@@ -30,26 +28,24 @@ from webdriver_manager.chrome import ChromeDriverManager
|
||||
from webdriver_manager.firefox import GeckoDriverManager
|
||||
from webdriver_manager.microsoft import EdgeChromiumDriverManager as EdgeDriverManager
|
||||
|
||||
from autogpt.agents.utils.exceptions import CommandExecutionError
|
||||
from autogpt.agents.protocols import CommandProvider, DirectiveProvider
|
||||
from autogpt.command_decorator import command
|
||||
from autogpt.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.models.command import Command
|
||||
from autogpt.processing.html import extract_hyperlinks, format_hyperlinks
|
||||
from autogpt.processing.text import summarize_text
|
||||
from autogpt.processing.text import extract_information, summarize_text
|
||||
from autogpt.url_utils.validators import validate_url
|
||||
|
||||
COMMAND_CATEGORY = "web_browse"
|
||||
COMMAND_CATEGORY_TITLE = "Web Browsing"
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.agents.agent import Agent
|
||||
from autogpt.config import Config
|
||||
|
||||
from autogpt.utils.exceptions import CommandExecutionError, TooMuchOutputError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FILE_DIR = Path(__file__).parent.parent
|
||||
TOKENS_TO_TRIGGER_SUMMARY = 50
|
||||
MAX_RAW_CONTENT_LENGTH = 500
|
||||
LINKS_TO_RETURN = 20
|
||||
|
||||
|
||||
@@ -57,248 +53,324 @@ class BrowsingError(CommandExecutionError):
|
||||
"""An error occurred while trying to browse the page"""
|
||||
|
||||
|
||||
@command(
|
||||
"read_webpage",
|
||||
(
|
||||
"Read a webpage, and extract specific information from it"
|
||||
" if a question is specified."
|
||||
" If you are looking to extract specific information from the webpage,"
|
||||
" you should specify a question."
|
||||
),
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL to visit",
|
||||
required=True,
|
||||
class WebSeleniumComponent(DirectiveProvider, CommandProvider):
|
||||
"""Provides commands to browse the web using Selenium."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
model_info: ChatModelInfo,
|
||||
):
|
||||
self.legacy_config = config
|
||||
self.llm_provider = llm_provider
|
||||
self.model_info = model_info
|
||||
|
||||
def get_resources(self) -> Iterator[str]:
|
||||
yield "Ability to read websites."
|
||||
|
||||
def get_commands(self) -> Iterator[Command]:
|
||||
yield self.read_webpage
|
||||
|
||||
@command(
|
||||
["read_webpage"],
|
||||
(
|
||||
"Read a webpage, and extract specific information from it."
|
||||
" You must specify either topics_of_interest,"
|
||||
" a question, or get_raw_content."
|
||||
),
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description=(
|
||||
"A question that you want to answer using the content of the webpage."
|
||||
{
|
||||
"url": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description="The URL to visit",
|
||||
required=True,
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
@validate_url
|
||||
async def read_webpage(url: str, agent: Agent, question: str = "") -> str:
|
||||
"""Browse a website and return the answer and links to the user
|
||||
"topics_of_interest": JSONSchema(
|
||||
type=JSONSchema.Type.ARRAY,
|
||||
items=JSONSchema(type=JSONSchema.Type.STRING),
|
||||
description=(
|
||||
"A list of topics about which you want to extract information "
|
||||
"from the page."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
"question": JSONSchema(
|
||||
type=JSONSchema.Type.STRING,
|
||||
description=(
|
||||
"A question you want to answer using the content of the webpage."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
"get_raw_content": JSONSchema(
|
||||
type=JSONSchema.Type.BOOLEAN,
|
||||
description=(
|
||||
"If true, the unprocessed content of the webpage will be returned. "
|
||||
"This consumes a lot of tokens, so use it with caution."
|
||||
),
|
||||
required=False,
|
||||
),
|
||||
},
|
||||
)
|
||||
@validate_url
|
||||
async def read_webpage(
|
||||
self,
|
||||
url: str,
|
||||
*,
|
||||
topics_of_interest: list[str] = [],
|
||||
get_raw_content: bool = False,
|
||||
question: str = "",
|
||||
) -> str:
|
||||
"""Browse a website and return the answer and links to the user
|
||||
|
||||
Args:
|
||||
url (str): The url of the website to browse
|
||||
question (str): The question to answer using the content of the webpage
|
||||
Args:
|
||||
url (str): The url of the website to browse
|
||||
question (str): The question to answer using the content of the webpage
|
||||
|
||||
Returns:
|
||||
str: The answer and links to the user and the webdriver
|
||||
"""
|
||||
driver = None
|
||||
try:
|
||||
# FIXME: agent.config -> something else
|
||||
driver = open_page_in_browser(url, agent.legacy_config)
|
||||
Returns:
|
||||
str: The answer and links to the user and the webdriver
|
||||
"""
|
||||
driver = None
|
||||
try:
|
||||
driver = await self.open_page_in_browser(url, self.legacy_config)
|
||||
|
||||
text = scrape_text_with_selenium(driver)
|
||||
links = scrape_links_with_selenium(driver, url)
|
||||
text = self.scrape_text_with_selenium(driver)
|
||||
links = self.scrape_links_with_selenium(driver, url)
|
||||
|
||||
return_literal_content = True
|
||||
summarized = False
|
||||
return_literal_content = True
|
||||
summarized = False
|
||||
if not text:
|
||||
return f"Website did not contain any text.\n\nLinks: {links}"
|
||||
elif get_raw_content:
|
||||
if (
|
||||
output_tokens := self.llm_provider.count_tokens(
|
||||
text, self.model_info.name
|
||||
)
|
||||
) > MAX_RAW_CONTENT_LENGTH:
|
||||
oversize_factor = round(output_tokens / MAX_RAW_CONTENT_LENGTH, 1)
|
||||
raise TooMuchOutputError(
|
||||
f"Page content is {oversize_factor}x the allowed length "
|
||||
"for `get_raw_content=true`"
|
||||
)
|
||||
return text + (f"\n\nLinks: {links}" if links else "")
|
||||
else:
|
||||
text = await self.summarize_webpage(
|
||||
text, question or None, topics_of_interest
|
||||
)
|
||||
return_literal_content = bool(question)
|
||||
summarized = True
|
||||
|
||||
# Limit links to LINKS_TO_RETURN
|
||||
if len(links) > LINKS_TO_RETURN:
|
||||
links = links[:LINKS_TO_RETURN]
|
||||
|
||||
text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
|
||||
links_fmt = "\n".join(f"- {link}" for link in links)
|
||||
return (
|
||||
f"Page content{' (summary)' if summarized else ''}:"
|
||||
if return_literal_content
|
||||
else "Answer gathered from webpage:"
|
||||
) + f" {text_fmt}\n\nLinks:\n{links_fmt}"
|
||||
|
||||
except WebDriverException as e:
|
||||
# These errors are often quite long and include lots of context.
|
||||
# Just grab the first line.
|
||||
msg = e.msg.split("\n")[0] if e.msg else str(e)
|
||||
if "net::" in msg:
|
||||
raise BrowsingError(
|
||||
"A networking error occurred while trying to load the page: %s"
|
||||
% re.sub(r"^unknown error: ", "", msg)
|
||||
)
|
||||
raise CommandExecutionError(msg)
|
||||
finally:
|
||||
if driver:
|
||||
driver.close()
|
||||
|
||||
def scrape_text_with_selenium(self, driver: WebDriver) -> str:
|
||||
"""Scrape text from a browser window using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
|
||||
Returns:
|
||||
str: the text scraped from the website
|
||||
"""
|
||||
|
||||
# Get the HTML content directly from the browser's DOM
|
||||
page_source = driver.execute_script("return document.body.outerHTML;")
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = "\n".join(chunk for chunk in chunks if chunk)
|
||||
return text
|
||||
|
||||
def scrape_links_with_selenium(self, driver: WebDriver, base_url: str) -> list[str]:
|
||||
"""Scrape links from a website using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
base_url (str): The base URL to use for resolving relative links
|
||||
|
||||
Returns:
|
||||
List[str]: The links scraped from the website
|
||||
"""
|
||||
page_source = driver.page_source
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
hyperlinks = extract_hyperlinks(soup, base_url)
|
||||
|
||||
return format_hyperlinks(hyperlinks)
|
||||
|
||||
async def open_page_in_browser(self, url: str, config: Config) -> WebDriver:
|
||||
"""Open a browser window and load a web page using Selenium
|
||||
|
||||
Params:
|
||||
url (str): The URL of the page to load
|
||||
config (Config): The applicable application configuration
|
||||
|
||||
Returns:
|
||||
driver (WebDriver): A driver object representing
|
||||
the browser window to scrape
|
||||
"""
|
||||
logging.getLogger("selenium").setLevel(logging.CRITICAL)
|
||||
|
||||
options_available: dict[str, Type[BrowserOptions]] = {
|
||||
"chrome": ChromeOptions,
|
||||
"edge": EdgeOptions,
|
||||
"firefox": FirefoxOptions,
|
||||
"safari": SafariOptions,
|
||||
}
|
||||
|
||||
options: BrowserOptions = options_available[config.selenium_web_browser]()
|
||||
options.add_argument(f"user-agent={config.user_agent}")
|
||||
|
||||
if isinstance(options, FirefoxOptions):
|
||||
if config.selenium_headless:
|
||||
options.headless = True
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
service=GeckoDriverService(GeckoDriverManager().install()),
|
||||
options=options,
|
||||
)
|
||||
elif isinstance(options, EdgeOptions):
|
||||
driver = EdgeDriver(
|
||||
service=EdgeDriverService(EdgeDriverManager().install()),
|
||||
options=options,
|
||||
)
|
||||
elif isinstance(options, SafariOptions):
|
||||
# Requires a bit more setup on the users end.
|
||||
# See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari # noqa: E501
|
||||
driver = SafariDriver(options=options)
|
||||
elif isinstance(options, ChromeOptions):
|
||||
if platform == "linux" or platform == "linux2":
|
||||
options.add_argument("--disable-dev-shm-usage")
|
||||
options.add_argument("--remote-debugging-port=9222")
|
||||
|
||||
options.add_argument("--no-sandbox")
|
||||
if config.selenium_headless:
|
||||
options.add_argument("--headless=new")
|
||||
options.add_argument("--disable-gpu")
|
||||
|
||||
self._sideload_chrome_extensions(
|
||||
options, config.app_data_dir / "assets" / "crx"
|
||||
)
|
||||
|
||||
if (chromium_driver_path := Path("/usr/bin/chromedriver")).exists():
|
||||
chrome_service = ChromeDriverService(str(chromium_driver_path))
|
||||
else:
|
||||
try:
|
||||
chrome_driver = ChromeDriverManager().install()
|
||||
except AttributeError as e:
|
||||
if "'NoneType' object has no attribute 'split'" in str(e):
|
||||
# https://github.com/SergeyPirogov/webdriver_manager/issues/649
|
||||
logger.critical(
|
||||
"Connecting to browser failed:"
|
||||
" is Chrome or Chromium installed?"
|
||||
)
|
||||
raise
|
||||
chrome_service = ChromeDriverService(chrome_driver)
|
||||
driver = ChromeDriver(service=chrome_service, options=options)
|
||||
|
||||
driver.get(url)
|
||||
|
||||
# Wait for page to be ready, sleep 2 seconds, wait again until page ready.
|
||||
# This allows the cookiewall squasher time to get rid of cookie walls.
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
return driver
|
||||
|
||||
def _sideload_chrome_extensions(
|
||||
self, options: ChromeOptions, dl_folder: Path
|
||||
) -> None:
|
||||
crx_download_url_template = "https://clients2.google.com/service/update2/crx?response=redirect&prodversion=49.0&acceptformat=crx3&x=id%3D{crx_id}%26installsource%3Dondemand%26uc" # noqa
|
||||
cookiewall_squasher_crx_id = "edibdbjcniadpccecjdfdjjppcpchdlm"
|
||||
adblocker_crx_id = "cjpalhdlnbpafiamejdnhcphjbkeiagm"
|
||||
|
||||
# Make sure the target folder exists
|
||||
dl_folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for crx_id in (cookiewall_squasher_crx_id, adblocker_crx_id):
|
||||
crx_path = dl_folder / f"{crx_id}.crx"
|
||||
if not crx_path.exists():
|
||||
logger.debug(f"Downloading CRX {crx_id}...")
|
||||
crx_download_url = crx_download_url_template.format(crx_id=crx_id)
|
||||
urlretrieve(crx_download_url, crx_path)
|
||||
logger.debug(f"Downloaded {crx_path.name}")
|
||||
options.add_extension(str(crx_path))
|
||||
|
||||
async def summarize_webpage(
|
||||
self,
|
||||
text: str,
|
||||
question: str | None,
|
||||
topics_of_interest: list[str],
|
||||
) -> str:
|
||||
"""Summarize text using the OpenAI API
|
||||
|
||||
Args:
|
||||
url (str): The url of the text
|
||||
text (str): The text to summarize
|
||||
question (str): The question to ask the model
|
||||
driver (WebDriver): The webdriver to use to scroll the page
|
||||
|
||||
Returns:
|
||||
str: The summary of the text
|
||||
"""
|
||||
if not text:
|
||||
return f"Website did not contain any text.\n\nLinks: {links}"
|
||||
elif (
|
||||
agent.llm_provider.count_tokens(text, agent.llm.name)
|
||||
> TOKENS_TO_TRIGGER_SUMMARY
|
||||
):
|
||||
text = await summarize_memorize_webpage(
|
||||
url, text, question or None, agent, driver
|
||||
raise ValueError("No text to summarize")
|
||||
|
||||
text_length = len(text)
|
||||
logger.debug(f"Web page content length: {text_length} characters")
|
||||
|
||||
result = None
|
||||
information = None
|
||||
if topics_of_interest:
|
||||
information = await extract_information(
|
||||
text,
|
||||
topics_of_interest=topics_of_interest,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
)
|
||||
return_literal_content = bool(question)
|
||||
summarized = True
|
||||
|
||||
# Limit links to LINKS_TO_RETURN
|
||||
if len(links) > LINKS_TO_RETURN:
|
||||
links = links[:LINKS_TO_RETURN]
|
||||
|
||||
text_fmt = f"'''{text}'''" if "\n" in text else f"'{text}'"
|
||||
links_fmt = "\n".join(f"- {link}" for link in links)
|
||||
return (
|
||||
f"Page content{' (summary)' if summarized else ''}:"
|
||||
if return_literal_content
|
||||
else "Answer gathered from webpage:"
|
||||
) + f" {text_fmt}\n\nLinks:\n{links_fmt}"
|
||||
|
||||
except WebDriverException as e:
|
||||
# These errors are often quite long and include lots of context.
|
||||
# Just grab the first line.
|
||||
msg = e.msg.split("\n")[0]
|
||||
if "net::" in msg:
|
||||
raise BrowsingError(
|
||||
"A networking error occurred while trying to load the page: %s"
|
||||
% re.sub(r"^unknown error: ", "", msg)
|
||||
return "\n".join(f"* {i}" for i in information)
|
||||
else:
|
||||
result, _ = await summarize_text(
|
||||
text,
|
||||
question=question,
|
||||
llm_provider=self.llm_provider,
|
||||
config=self.legacy_config,
|
||||
)
|
||||
raise CommandExecutionError(msg)
|
||||
finally:
|
||||
if driver:
|
||||
close_browser(driver)
|
||||
|
||||
|
||||
def scrape_text_with_selenium(driver: WebDriver) -> str:
|
||||
"""Scrape text from a browser window using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
|
||||
Returns:
|
||||
str: the text scraped from the website
|
||||
"""
|
||||
|
||||
# Get the HTML content directly from the browser's DOM
|
||||
page_source = driver.execute_script("return document.body.outerHTML;")
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
text = soup.get_text()
|
||||
lines = (line.strip() for line in text.splitlines())
|
||||
chunks = (phrase.strip() for line in lines for phrase in line.split(" "))
|
||||
text = "\n".join(chunk for chunk in chunks if chunk)
|
||||
return text
|
||||
|
||||
|
||||
def scrape_links_with_selenium(driver: WebDriver, base_url: str) -> list[str]:
|
||||
"""Scrape links from a website using selenium
|
||||
|
||||
Args:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
base_url (str): The base URL to use for resolving relative links
|
||||
|
||||
Returns:
|
||||
List[str]: The links scraped from the website
|
||||
"""
|
||||
page_source = driver.page_source
|
||||
soup = BeautifulSoup(page_source, "html.parser")
|
||||
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
|
||||
hyperlinks = extract_hyperlinks(soup, base_url)
|
||||
|
||||
return format_hyperlinks(hyperlinks)
|
||||
|
||||
|
||||
def open_page_in_browser(url: str, config: Config) -> WebDriver:
|
||||
"""Open a browser window and load a web page using Selenium
|
||||
|
||||
Params:
|
||||
url (str): The URL of the page to load
|
||||
config (Config): The applicable application configuration
|
||||
|
||||
Returns:
|
||||
driver (WebDriver): A driver object representing the browser window to scrape
|
||||
"""
|
||||
logging.getLogger("selenium").setLevel(logging.CRITICAL)
|
||||
|
||||
options_available: dict[str, Type[BrowserOptions]] = {
|
||||
"chrome": ChromeOptions,
|
||||
"edge": EdgeOptions,
|
||||
"firefox": FirefoxOptions,
|
||||
"safari": SafariOptions,
|
||||
}
|
||||
|
||||
options: BrowserOptions = options_available[config.selenium_web_browser]()
|
||||
options.add_argument(f"user-agent={config.user_agent}")
|
||||
|
||||
if config.selenium_web_browser == "firefox":
|
||||
if config.selenium_headless:
|
||||
options.headless = True
|
||||
options.add_argument("--disable-gpu")
|
||||
driver = FirefoxDriver(
|
||||
service=GeckoDriverService(GeckoDriverManager().install()), options=options
|
||||
)
|
||||
elif config.selenium_web_browser == "edge":
|
||||
driver = EdgeDriver(
|
||||
service=EdgeDriverService(EdgeDriverManager().install()), options=options
|
||||
)
|
||||
elif config.selenium_web_browser == "safari":
|
||||
# Requires a bit more setup on the users end.
|
||||
# See https://developer.apple.com/documentation/webkit/testing_with_webdriver_in_safari # noqa: E501
|
||||
driver = SafariDriver(options=options)
|
||||
else:
|
||||
if platform == "linux" or platform == "linux2":
|
||||
options.add_argument("--disable-dev-shm-usage")
|
||||
options.add_argument("--remote-debugging-port=9222")
|
||||
|
||||
options.add_argument("--no-sandbox")
|
||||
if config.selenium_headless:
|
||||
options.add_argument("--headless=new")
|
||||
options.add_argument("--disable-gpu")
|
||||
|
||||
chromium_driver_path = Path("/usr/bin/chromedriver")
|
||||
|
||||
driver = ChromeDriver(
|
||||
service=ChromeDriverService(str(chromium_driver_path))
|
||||
if chromium_driver_path.exists()
|
||||
else ChromeDriverService(ChromeDriverManager().install()),
|
||||
options=options,
|
||||
)
|
||||
driver.get(url)
|
||||
|
||||
WebDriverWait(driver, 10).until(
|
||||
EC.presence_of_element_located((By.TAG_NAME, "body"))
|
||||
)
|
||||
|
||||
return driver
|
||||
|
||||
|
||||
def close_browser(driver: WebDriver) -> None:
|
||||
"""Close the browser
|
||||
|
||||
Args:
|
||||
driver (WebDriver): The webdriver to close
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
driver.quit()
|
||||
|
||||
|
||||
async def summarize_memorize_webpage(
|
||||
url: str,
|
||||
text: str,
|
||||
question: str | None,
|
||||
agent: Agent,
|
||||
driver: Optional[WebDriver] = None,
|
||||
) -> str:
|
||||
"""Summarize text using the OpenAI API
|
||||
|
||||
Args:
|
||||
url (str): The url of the text
|
||||
text (str): The text to summarize
|
||||
question (str): The question to ask the model
|
||||
driver (WebDriver): The webdriver to use to scroll the page
|
||||
|
||||
Returns:
|
||||
str: The summary of the text
|
||||
"""
|
||||
if not text:
|
||||
raise ValueError("No text to summarize")
|
||||
|
||||
text_length = len(text)
|
||||
logger.info(f"Text length: {text_length} characters")
|
||||
|
||||
# memory = get_memory(agent.legacy_config)
|
||||
|
||||
# new_memory = MemoryItem.from_webpage(
|
||||
# content=text,
|
||||
# url=url,
|
||||
# config=agent.legacy_config,
|
||||
# question=question,
|
||||
# )
|
||||
# memory.add(new_memory)
|
||||
|
||||
summary, _ = await summarize_text(
|
||||
text,
|
||||
question=question,
|
||||
llm_provider=agent.llm_provider,
|
||||
config=agent.legacy_config, # FIXME
|
||||
)
|
||||
return summary
|
||||
return result
|
||||
|
||||
82
autogpts/autogpt/autogpt/components/event_history.py
Normal file
82
autogpts/autogpt/autogpt/components/event_history.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from typing import Callable, Generic, Iterator, Optional
|
||||
|
||||
from autogpt.agents.features.watchdog import WatchdogComponent
|
||||
from autogpt.agents.protocols import AfterExecute, AfterParse, MessageProvider
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers.schema import ChatMessage, ChatModelProvider
|
||||
from autogpt.models.action_history import (
|
||||
AP,
|
||||
ActionResult,
|
||||
Episode,
|
||||
EpisodicActionHistory,
|
||||
)
|
||||
from autogpt.prompts.utils import indent
|
||||
|
||||
|
||||
class EventHistoryComponent(MessageProvider, AfterParse, AfterExecute, Generic[AP]):
|
||||
"""Keeps track of the event history and provides a summary of the steps."""
|
||||
|
||||
run_after = [WatchdogComponent]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_history: EpisodicActionHistory[AP],
|
||||
max_tokens: int,
|
||||
count_tokens: Callable[[str], int],
|
||||
legacy_config: Config,
|
||||
llm_provider: ChatModelProvider,
|
||||
) -> None:
|
||||
self.event_history = event_history
|
||||
self.max_tokens = max_tokens
|
||||
self.count_tokens = count_tokens
|
||||
self.legacy_config = legacy_config
|
||||
self.llm_provider = llm_provider
|
||||
|
||||
def get_messages(self) -> Iterator[ChatMessage]:
|
||||
if progress := self._compile_progress(
|
||||
self.event_history.episodes,
|
||||
self.max_tokens,
|
||||
self.count_tokens,
|
||||
):
|
||||
yield ChatMessage.system(f"## Progress on your Task so far\n\n{progress}")
|
||||
|
||||
def after_parse(self, result: AP) -> None:
|
||||
self.event_history.register_action(result)
|
||||
|
||||
async def after_execute(self, result: ActionResult) -> None:
|
||||
self.event_history.register_result(result)
|
||||
await self.event_history.handle_compression(
|
||||
self.llm_provider, self.legacy_config
|
||||
)
|
||||
|
||||
def _compile_progress(
|
||||
self,
|
||||
episode_history: list[Episode],
|
||||
max_tokens: Optional[int] = None,
|
||||
count_tokens: Optional[Callable[[str], int]] = None,
|
||||
) -> str:
|
||||
if max_tokens and not count_tokens:
|
||||
raise ValueError("count_tokens is required if max_tokens is set")
|
||||
|
||||
steps: list[str] = []
|
||||
tokens: int = 0
|
||||
n_episodes = len(episode_history)
|
||||
|
||||
for i, episode in enumerate(reversed(episode_history)):
|
||||
# Use full format for the latest 4 steps, summary or format for older steps
|
||||
if i < 4 or episode.summary is None:
|
||||
step_content = indent(episode.format(), 2).strip()
|
||||
else:
|
||||
step_content = episode.summary
|
||||
|
||||
step = f"* Step {n_episodes - i}: {step_content}"
|
||||
|
||||
if max_tokens and count_tokens:
|
||||
step_tokens = count_tokens(step)
|
||||
if tokens + step_tokens > max_tokens:
|
||||
break
|
||||
tokens += step_tokens
|
||||
|
||||
steps.insert(0, step)
|
||||
|
||||
return "\n\n".join(steps)
|
||||
@@ -5,7 +5,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from autogpt.logs.helpers import request_user_double_check
|
||||
from autogpt.utils import validate_yaml_file
|
||||
from autogpt.utils.utils import validate_yaml_file
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,7 +32,7 @@ class AIDirectives(BaseModel):
|
||||
raise RuntimeError(f"File validation failed: {message}")
|
||||
|
||||
with open(prompt_settings_file, encoding="utf-8") as file:
|
||||
config_params = yaml.load(file, Loader=yaml.FullLoader)
|
||||
config_params = yaml.load(file, Loader=yaml.SafeLoader)
|
||||
|
||||
return AIDirectives(
|
||||
constraints=config_params.get("constraints", []),
|
||||
|
||||
@@ -35,7 +35,7 @@ class AIProfile(BaseModel):
|
||||
|
||||
try:
|
||||
with open(ai_settings_file, encoding="utf-8") as file:
|
||||
config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
|
||||
config_params = yaml.load(file, Loader=yaml.SafeLoader) or {}
|
||||
except FileNotFoundError:
|
||||
config_params = {}
|
||||
|
||||
|
||||
@@ -1,38 +1,40 @@
|
||||
"""Configuration class to store the state of bools for different scripts access."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from colorama import Fore
|
||||
from pydantic import Field, SecretStr, validator
|
||||
from pydantic import SecretStr, validator
|
||||
|
||||
import autogpt
|
||||
from autogpt.app.utils import clean_input
|
||||
from autogpt.core.configuration.schema import (
|
||||
Configurable,
|
||||
SystemSettings,
|
||||
UserConfigurable,
|
||||
)
|
||||
from autogpt.core.resource.model_providers import CHAT_MODELS, ModelName
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OpenAICredentials,
|
||||
OpenAIModelName,
|
||||
)
|
||||
from autogpt.file_workspace import FileWorkspaceBackendName
|
||||
from autogpt.file_storage import FileStorageBackendName
|
||||
from autogpt.logs.config import LoggingConfig
|
||||
from autogpt.plugins.plugins_config import PluginsConfig
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PROJECT_ROOT = Path(autogpt.__file__).parent.parent
|
||||
AI_SETTINGS_FILE = Path("ai_settings.yaml")
|
||||
AZURE_CONFIG_FILE = Path("azure.yaml")
|
||||
PLUGINS_CONFIG_FILE = Path("plugins_config.yaml")
|
||||
PROMPT_SETTINGS_FILE = Path("prompt_settings.yaml")
|
||||
|
||||
GPT_4_MODEL = "gpt-4"
|
||||
GPT_3_MODEL = "gpt-3.5-turbo"
|
||||
GPT_4_MODEL = OpenAIModelName.GPT4
|
||||
GPT_3_MODEL = OpenAIModelName.GPT3
|
||||
|
||||
|
||||
class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
@@ -49,20 +51,14 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
authorise_key: str = UserConfigurable(default="y", from_env="AUTHORISE_COMMAND_KEY")
|
||||
exit_key: str = UserConfigurable(default="n", from_env="EXIT_KEY")
|
||||
noninteractive_mode: bool = False
|
||||
chat_messages_enabled: bool = UserConfigurable(
|
||||
default=True, from_env=lambda: os.getenv("CHAT_MESSAGES_ENABLED") == "True"
|
||||
)
|
||||
|
||||
# TTS configuration
|
||||
tts_config: TTSConfig = TTSConfig()
|
||||
logging: LoggingConfig = LoggingConfig()
|
||||
tts_config: TTSConfig = TTSConfig()
|
||||
|
||||
# Workspace
|
||||
workspace_backend: FileWorkspaceBackendName = UserConfigurable(
|
||||
default=FileWorkspaceBackendName.LOCAL,
|
||||
from_env=lambda: FileWorkspaceBackendName(v)
|
||||
if (v := os.getenv("WORKSPACE_BACKEND"))
|
||||
else None,
|
||||
# File storage
|
||||
file_storage_backend: FileStorageBackendName = UserConfigurable(
|
||||
default=FileStorageBackendName.LOCAL, from_env="FILE_STORAGE_BACKEND"
|
||||
)
|
||||
|
||||
##########################
|
||||
@@ -70,32 +66,28 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
##########################
|
||||
# Paths
|
||||
ai_settings_file: Path = UserConfigurable(
|
||||
default=AI_SETTINGS_FILE,
|
||||
from_env=lambda: Path(f) if (f := os.getenv("AI_SETTINGS_FILE")) else None,
|
||||
default=AI_SETTINGS_FILE, from_env="AI_SETTINGS_FILE"
|
||||
)
|
||||
prompt_settings_file: Path = UserConfigurable(
|
||||
default=PROMPT_SETTINGS_FILE,
|
||||
from_env=lambda: Path(f) if (f := os.getenv("PROMPT_SETTINGS_FILE")) else None,
|
||||
from_env="PROMPT_SETTINGS_FILE",
|
||||
)
|
||||
|
||||
# Model configuration
|
||||
fast_llm: str = UserConfigurable(
|
||||
default="gpt-3.5-turbo-16k",
|
||||
from_env=lambda: os.getenv("FAST_LLM"),
|
||||
fast_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT3,
|
||||
from_env="FAST_LLM",
|
||||
)
|
||||
smart_llm: str = UserConfigurable(
|
||||
default="gpt-4",
|
||||
from_env=lambda: os.getenv("SMART_LLM"),
|
||||
)
|
||||
temperature: float = UserConfigurable(
|
||||
default=0,
|
||||
from_env=lambda: float(v) if (v := os.getenv("TEMPERATURE")) else None,
|
||||
smart_llm: ModelName = UserConfigurable(
|
||||
default=OpenAIModelName.GPT4_TURBO,
|
||||
from_env="SMART_LLM",
|
||||
)
|
||||
temperature: float = UserConfigurable(default=0, from_env="TEMPERATURE")
|
||||
openai_functions: bool = UserConfigurable(
|
||||
default=False, from_env=lambda: os.getenv("OPENAI_FUNCTIONS", "False") == "True"
|
||||
)
|
||||
embedding_model: str = UserConfigurable(
|
||||
default="text-embedding-ada-002", from_env="EMBEDDING_MODEL"
|
||||
default="text-embedding-3-small", from_env="EMBEDDING_MODEL"
|
||||
)
|
||||
browse_spacy_language_model: str = UserConfigurable(
|
||||
default="en_core_web_sm", from_env="BROWSE_SPACY_LANGUAGE_MODEL"
|
||||
@@ -111,10 +103,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
memory_backend: str = UserConfigurable("json_file", from_env="MEMORY_BACKEND")
|
||||
memory_index: str = UserConfigurable("auto-gpt-memory", from_env="MEMORY_INDEX")
|
||||
redis_host: str = UserConfigurable("localhost", from_env="REDIS_HOST")
|
||||
redis_port: int = UserConfigurable(
|
||||
default=6379,
|
||||
from_env=lambda: int(v) if (v := os.getenv("REDIS_PORT")) else None,
|
||||
)
|
||||
redis_port: int = UserConfigurable(default=6379, from_env="REDIS_PORT")
|
||||
redis_password: str = UserConfigurable("", from_env="REDIS_PASSWORD")
|
||||
wipe_redis_on_start: bool = UserConfigurable(
|
||||
default=True,
|
||||
@@ -125,9 +114,9 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Commands #
|
||||
############
|
||||
# General
|
||||
disabled_command_categories: list[str] = UserConfigurable(
|
||||
disabled_commands: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("DISABLED_COMMAND_CATEGORIES")),
|
||||
from_env=lambda: _safe_split(os.getenv("DISABLED_COMMANDS")),
|
||||
)
|
||||
|
||||
# File ops
|
||||
@@ -166,10 +155,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
sd_webui_url: Optional[str] = UserConfigurable(
|
||||
default="http://localhost:7860", from_env="SD_WEBUI_URL"
|
||||
)
|
||||
image_size: int = UserConfigurable(
|
||||
default=256,
|
||||
from_env=lambda: int(v) if (v := os.getenv("IMAGE_SIZE")) else None,
|
||||
)
|
||||
image_size: int = UserConfigurable(default=256, from_env="IMAGE_SIZE")
|
||||
|
||||
# Audio to text
|
||||
audio_to_text_provider: str = UserConfigurable(
|
||||
@@ -189,38 +175,13 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
from_env="USER_AGENT",
|
||||
)
|
||||
|
||||
###################
|
||||
# Plugin Settings #
|
||||
###################
|
||||
plugins_dir: str = UserConfigurable("plugins", from_env="PLUGINS_DIR")
|
||||
plugins_config_file: Path = UserConfigurable(
|
||||
default=PLUGINS_CONFIG_FILE,
|
||||
from_env=lambda: Path(f) if (f := os.getenv("PLUGINS_CONFIG_FILE")) else None,
|
||||
)
|
||||
plugins_config: PluginsConfig = Field(
|
||||
default_factory=lambda: PluginsConfig(plugins={})
|
||||
)
|
||||
plugins: list[AutoGPTPluginTemplate] = Field(default_factory=list, exclude=True)
|
||||
plugins_allowlist: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("ALLOWLISTED_PLUGINS")),
|
||||
)
|
||||
plugins_denylist: list[str] = UserConfigurable(
|
||||
default_factory=list,
|
||||
from_env=lambda: _safe_split(os.getenv("DENYLISTED_PLUGINS")),
|
||||
)
|
||||
plugins_openai: list[str] = UserConfigurable(
|
||||
default_factory=list, from_env=lambda: _safe_split(os.getenv("OPENAI_PLUGINS"))
|
||||
)
|
||||
|
||||
###############
|
||||
# Credentials #
|
||||
###############
|
||||
# OpenAI
|
||||
openai_credentials: Optional[OpenAICredentials] = None
|
||||
azure_config_file: Optional[Path] = UserConfigurable(
|
||||
default=AZURE_CONFIG_FILE,
|
||||
from_env=lambda: Path(f) if (f := os.getenv("AZURE_CONFIG_FILE")) else None,
|
||||
default=AZURE_CONFIG_FILE, from_env="AZURE_CONFIG_FILE"
|
||||
)
|
||||
|
||||
# Github
|
||||
@@ -230,7 +191,7 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Google
|
||||
google_api_key: Optional[str] = UserConfigurable(from_env="GOOGLE_API_KEY")
|
||||
google_custom_search_engine_id: Optional[str] = UserConfigurable(
|
||||
from_env=lambda: os.getenv("GOOGLE_CUSTOM_SEARCH_ENGINE_ID"),
|
||||
from_env="GOOGLE_CUSTOM_SEARCH_ENGINE_ID",
|
||||
)
|
||||
|
||||
# Huggingface
|
||||
@@ -241,22 +202,12 @@ class Config(SystemSettings, arbitrary_types_allowed=True):
|
||||
# Stable Diffusion
|
||||
sd_webui_auth: Optional[str] = UserConfigurable(from_env="SD_WEBUI_AUTH")
|
||||
|
||||
@validator("plugins", each_item=True)
|
||||
def validate_plugins(cls, p: AutoGPTPluginTemplate | Any):
|
||||
assert issubclass(
|
||||
p.__class__, AutoGPTPluginTemplate
|
||||
), f"{p} does not subclass AutoGPTPluginTemplate"
|
||||
assert (
|
||||
p.__class__.__name__ != "AutoGPTPluginTemplate"
|
||||
), f"Plugins must subclass AutoGPTPluginTemplate; {p} is a template instance"
|
||||
return p
|
||||
|
||||
@validator("openai_functions")
|
||||
def validate_openai_functions(cls, v: bool, values: dict[str, Any]):
|
||||
if v:
|
||||
smart_llm = values["smart_llm"]
|
||||
assert OPEN_AI_CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support OpenAI Functions. "
|
||||
assert CHAT_MODELS[smart_llm].has_function_call_api, (
|
||||
f"Model {smart_llm} does not support tool calling. "
|
||||
"Please disable OPENAI_FUNCTIONS or choose a suitable model."
|
||||
)
|
||||
return v
|
||||
@@ -276,7 +227,6 @@ class ConfigBuilder(Configurable[Config]):
|
||||
for k in {
|
||||
"ai_settings_file", # TODO: deprecate or repurpose
|
||||
"prompt_settings_file", # TODO: deprecate or repurpose
|
||||
"plugins_config_file", # TODO: move from project root
|
||||
"azure_config_file", # TODO: move from project root
|
||||
}:
|
||||
setattr(config, k, project_root / getattr(config, k))
|
||||
@@ -288,45 +238,56 @@ class ConfigBuilder(Configurable[Config]):
|
||||
):
|
||||
config.openai_credentials.load_azure_config(config_file)
|
||||
|
||||
config.plugins_config = PluginsConfig.load_config(
|
||||
config.plugins_config_file,
|
||||
config.plugins_denylist,
|
||||
config.plugins_allowlist,
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def assert_config_has_openai_api_key(config: Config) -> None:
|
||||
"""Check if the OpenAI API key is set in config.py or as an environment variable."""
|
||||
if not config.openai_credentials:
|
||||
print(
|
||||
Fore.RED
|
||||
+ "Please set your OpenAI API key in .env or as an environment variable."
|
||||
+ Fore.RESET
|
||||
key_pattern = r"^sk-(proj-)?\w{48}"
|
||||
openai_api_key = (
|
||||
config.openai_credentials.api_key.get_secret_value()
|
||||
if config.openai_credentials
|
||||
else ""
|
||||
)
|
||||
|
||||
# If there's no credentials or empty API key, prompt the user to set it
|
||||
if not openai_api_key:
|
||||
logger.error(
|
||||
"Please set your OpenAI API key in .env or as an environment variable."
|
||||
)
|
||||
print("You can get your key from https://platform.openai.com/account/api-keys")
|
||||
openai_api_key = input(
|
||||
"If you do have the key, please enter your OpenAI API key now:\n"
|
||||
logger.info(
|
||||
"You can get your key from https://platform.openai.com/account/api-keys"
|
||||
)
|
||||
openai_api_key = clean_input(
|
||||
config, "Please enter your OpenAI API key if you have it:"
|
||||
)
|
||||
key_pattern = r"^sk-\w{48}"
|
||||
openai_api_key = openai_api_key.strip()
|
||||
if re.search(key_pattern, openai_api_key):
|
||||
os.environ["OPENAI_API_KEY"] = openai_api_key
|
||||
config.openai_credentials = OpenAICredentials(
|
||||
api_key=SecretStr(openai_api_key)
|
||||
)
|
||||
if config.openai_credentials:
|
||||
config.openai_credentials.api_key = SecretStr(openai_api_key)
|
||||
else:
|
||||
config.openai_credentials = OpenAICredentials(
|
||||
api_key=SecretStr(openai_api_key)
|
||||
)
|
||||
print("OpenAI API key successfully set!")
|
||||
print(
|
||||
Fore.GREEN
|
||||
+ "OpenAI API key successfully set!\n"
|
||||
+ Fore.YELLOW
|
||||
+ "NOTE: The API key you've set is only temporary.\n"
|
||||
+ "For longer sessions, please set it in .env file"
|
||||
+ Fore.RESET
|
||||
f"{Fore.YELLOW}NOTE: The API key you've set is only temporary. "
|
||||
f"For longer sessions, please set it in the .env file{Fore.RESET}"
|
||||
)
|
||||
else:
|
||||
print("Invalid OpenAI API key!")
|
||||
print(f"{Fore.RED}Invalid OpenAI API key{Fore.RESET}")
|
||||
exit(1)
|
||||
# If key is set, but it looks invalid
|
||||
elif not re.search(key_pattern, openai_api_key):
|
||||
logger.error(
|
||||
"Invalid OpenAI API key! "
|
||||
"Please set your OpenAI API key in .env or as an environment variable."
|
||||
)
|
||||
logger.info(
|
||||
"You can get your key from https://platform.openai.com/account/api-keys"
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
||||
def _safe_split(s: Union[str, None], sep: str = ",") -> list[str]:
|
||||
|
||||
@@ -4,7 +4,7 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.planning.schema import Task, TaskType
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads, to_numbered_list
|
||||
from autogpt.core.prompting.utils import to_numbered_list
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
@@ -194,9 +194,7 @@ class InitialPlan(PromptStrategy):
|
||||
f"LLM did not call {self._create_plan_function.name} function; "
|
||||
"plan creation failed"
|
||||
)
|
||||
parsed_response: object = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
parsed_response: object = response_content.tool_calls[0].function.arguments
|
||||
parsed_response["task_list"] = [
|
||||
Task.parse_obj(task) for task in parsed_response["task_list"]
|
||||
]
|
||||
|
||||
@@ -3,7 +3,6 @@ import logging
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
@@ -141,9 +140,7 @@ class NameAndGoals(PromptStrategy):
|
||||
f"LLM did not call {self._create_agent_function} function; "
|
||||
"agent profile creation failed"
|
||||
)
|
||||
parsed_response = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
parsed_response = response_content.tool_calls[0].function.arguments
|
||||
except KeyError:
|
||||
logger.debug(f"Failed to parse this response content: {response_content}")
|
||||
raise
|
||||
|
||||
@@ -4,7 +4,7 @@ from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
from autogpt.core.planning.schema import Task
|
||||
from autogpt.core.prompting import PromptStrategy
|
||||
from autogpt.core.prompting.schema import ChatPrompt, LanguageModelClassification
|
||||
from autogpt.core.prompting.utils import json_loads, to_numbered_list
|
||||
from autogpt.core.prompting.utils import to_numbered_list
|
||||
from autogpt.core.resource.model_providers import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
@@ -187,9 +187,7 @@ class NextAbility(PromptStrategy):
|
||||
raise ValueError("LLM did not call any function")
|
||||
|
||||
function_name = response_content.tool_calls[0].function.name
|
||||
function_arguments = json_loads(
|
||||
response_content.tool_calls[0].function.arguments
|
||||
)
|
||||
function_arguments = response_content.tool_calls[0].function.arguments
|
||||
parsed_response = {
|
||||
"motivation": function_arguments.pop("motivation"),
|
||||
"self_criticism": function_arguments.pop("self_criticism"),
|
||||
|
||||
@@ -24,6 +24,7 @@ class LanguageModelClassification(str, enum.Enum):
|
||||
class ChatPrompt(BaseModel):
|
||||
messages: list[ChatMessage]
|
||||
functions: list[CompletionModelFunction] = Field(default_factory=list)
|
||||
prefill_response: str = ""
|
||||
|
||||
def raw(self) -> list[ChatMessageDict]:
|
||||
return [m.dict() for m in self.messages]
|
||||
|
||||
@@ -1,7 +1,3 @@
|
||||
import ast
|
||||
import json
|
||||
|
||||
|
||||
def to_numbered_list(
|
||||
items: list[str], no_items_response: str = "", **template_args
|
||||
) -> str:
|
||||
@@ -11,19 +7,3 @@ def to_numbered_list(
|
||||
)
|
||||
else:
|
||||
return no_items_response
|
||||
|
||||
|
||||
def json_loads(json_str: str):
|
||||
# TODO: this is a hack function for now. We'll see what errors show up in testing.
|
||||
# Can hopefully just replace with a call to ast.literal_eval.
|
||||
# Can't use json.loads because the function API still sometimes returns json strings
|
||||
# with minor issues like trailing commas.
|
||||
try:
|
||||
json_str = json_str[json_str.index("{") : json_str.rindex("}") + 1]
|
||||
return ast.literal_eval(json_str)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
try:
|
||||
print(f"json decode error {e}. trying literal eval")
|
||||
return ast.literal_eval(json_str)
|
||||
except Exception:
|
||||
breakpoint()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from .multi import CHAT_MODELS, ModelName, MultiProvider
|
||||
from .openai import (
|
||||
OPEN_AI_CHAT_MODELS,
|
||||
OPEN_AI_EMBEDDING_MODELS,
|
||||
@@ -42,11 +43,13 @@ __all__ = [
|
||||
"ChatModelProvider",
|
||||
"ChatModelResponse",
|
||||
"CompletionModelFunction",
|
||||
"CHAT_MODELS",
|
||||
"Embedding",
|
||||
"EmbeddingModelInfo",
|
||||
"EmbeddingModelProvider",
|
||||
"EmbeddingModelResponse",
|
||||
"ModelInfo",
|
||||
"ModelName",
|
||||
"ModelProvider",
|
||||
"ModelProviderBudget",
|
||||
"ModelProviderCredentials",
|
||||
@@ -56,6 +59,7 @@ __all__ = [
|
||||
"ModelProviderUsage",
|
||||
"ModelResponse",
|
||||
"ModelTokenizer",
|
||||
"MultiProvider",
|
||||
"OPEN_AI_MODELS",
|
||||
"OPEN_AI_CHAT_MODELS",
|
||||
"OPEN_AI_EMBEDDING_MODELS",
|
||||
|
||||
@@ -0,0 +1,495 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
import tiktoken
|
||||
from anthropic import APIConnectionError, APIStatusError
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
ToolResultMessage,
|
||||
)
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.beta.tools import MessageCreateParams
|
||||
from anthropic.types.beta.tools import ToolsBetaMessage as Message
|
||||
from anthropic.types.beta.tools import ToolsBetaMessageParam as MessageParam
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
|
||||
class AnthropicModelName(str, enum.Enum):
|
||||
CLAUDE3_OPUS_v1 = "claude-3-opus-20240229"
|
||||
CLAUDE3_SONNET_v1 = "claude-3-sonnet-20240229"
|
||||
CLAUDE3_HAIKU_v1 = "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
ANTHROPIC_CHAT_MODELS = {
|
||||
info.name: info
|
||||
for info in [
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_OPUS_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=15 / 1e6,
|
||||
completion_token_cost=75 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_SONNET_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=3 / 1e6,
|
||||
completion_token_cost=15 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=AnthropicModelName.CLAUDE3_HAIKU_v1,
|
||||
provider_name=ModelProviderName.ANTHROPIC,
|
||||
prompt_token_cost=0.25 / 1e6,
|
||||
completion_token_cost=1.25 / 1e6,
|
||||
max_tokens=200000,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class AnthropicConfiguration(ModelProviderConfiguration):
|
||||
fix_failed_parse_tries: int = UserConfigurable(3)
|
||||
|
||||
|
||||
class AnthropicCredentials(ModelProviderCredentials):
|
||||
"""Credentials for Anthropic."""
|
||||
|
||||
api_key: SecretStr = UserConfigurable(from_env="ANTHROPIC_API_KEY")
|
||||
api_base: Optional[SecretStr] = UserConfigurable(
|
||||
default=None, from_env="ANTHROPIC_API_BASE_URL"
|
||||
)
|
||||
|
||||
def get_api_access_kwargs(self) -> dict[str, str]:
|
||||
return {
|
||||
k: (v.get_secret_value() if type(v) is SecretStr else v)
|
||||
for k, v in {
|
||||
"api_key": self.api_key,
|
||||
"base_url": self.api_base,
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
|
||||
class AnthropicSettings(ModelProviderSettings):
|
||||
configuration: AnthropicConfiguration
|
||||
credentials: Optional[AnthropicCredentials]
|
||||
budget: ModelProviderBudget
|
||||
|
||||
|
||||
class AnthropicProvider(Configurable[AnthropicSettings], ChatModelProvider):
|
||||
default_settings = AnthropicSettings(
|
||||
name="anthropic_provider",
|
||||
description="Provides access to Anthropic's API.",
|
||||
configuration=AnthropicConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
credentials=None,
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_settings: AnthropicSettings
|
||||
_configuration: AnthropicConfiguration
|
||||
_credentials: AnthropicCredentials
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[AnthropicSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
if not settings.credentials:
|
||||
settings.credentials = AnthropicCredentials.from_env()
|
||||
|
||||
super(AnthropicProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
self._client = AsyncAnthropic(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
return list(ANTHROPIC_CHAT_MODELS.values())
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return ANTHROPIC_CHAT_MODELS[model_name].max_tokens
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: AnthropicModelName) -> ModelTokenizer:
|
||||
# HACK: No official tokenizer is available for Claude 3
|
||||
return tiktoken.encoding_for_model(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: AnthropicModelName) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls,
|
||||
messages: ChatMessage | list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
) -> int:
|
||||
return 0 # HACK: No official tokenizer is available for Claude 3
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: AnthropicModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
anthropic_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
prompt_messages=model_prompt,
|
||||
model=model_name,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
total_cost = 0.0
|
||||
attempts = 0
|
||||
while True:
|
||||
completion_kwargs["messages"] = anthropic_messages.copy()
|
||||
if prefill_response:
|
||||
completion_kwargs["messages"].append(
|
||||
{"role": "assistant", "content": prefill_response}
|
||||
)
|
||||
|
||||
(
|
||||
_assistant_msg,
|
||||
cost,
|
||||
t_input,
|
||||
t_output,
|
||||
) = await self._create_chat_completion(completion_kwargs)
|
||||
total_cost += cost
|
||||
self._logger.debug(
|
||||
f"Completion usage: {t_input} input, {t_output} output "
|
||||
f"- ${round(cost, 5)}"
|
||||
)
|
||||
|
||||
# Merge prefill into generated response
|
||||
if prefill_response:
|
||||
first_text_block = next(
|
||||
b for b in _assistant_msg.content if b.type == "text"
|
||||
)
|
||||
first_text_block.text = prefill_response + first_text_block.text
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content="\n\n".join(
|
||||
b.text for b in _assistant_msg.content if b.type == "text"
|
||||
),
|
||||
tool_calls=self._parse_assistant_tool_calls(_assistant_msg),
|
||||
)
|
||||
|
||||
# If parsing the response fails, append the error to the prompt, and let the
|
||||
# LLM fix its mistake(s).
|
||||
attempts += 1
|
||||
tool_call_errors = []
|
||||
try:
|
||||
# Validate tool calls
|
||||
if assistant_msg.tool_calls and functions:
|
||||
tool_call_errors = validate_tool_calls(
|
||||
assistant_msg.tool_calls, functions
|
||||
)
|
||||
if tool_call_errors:
|
||||
raise ValueError(
|
||||
"Invalid tool use(s):\n"
|
||||
+ "\n".join(str(e) for e in tool_call_errors)
|
||||
)
|
||||
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||
)
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
anthropic_messages.append(
|
||||
_assistant_msg.dict(include={"role", "content"})
|
||||
)
|
||||
anthropic_messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
*(
|
||||
# tool_result is required if last assistant message
|
||||
# had tool_use block(s)
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tc.id,
|
||||
"is_error": True,
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Not executed because parsing "
|
||||
"of your last message failed"
|
||||
if not tool_call_errors
|
||||
else str(e)
|
||||
if (
|
||||
e := next(
|
||||
(
|
||||
tce
|
||||
for tce in tool_call_errors
|
||||
if tce.name
|
||||
== tc.function.name
|
||||
),
|
||||
None,
|
||||
)
|
||||
)
|
||||
else "Not executed because validation "
|
||||
"of tool input failed",
|
||||
}
|
||||
],
|
||||
}
|
||||
for tc in assistant_msg.tool_calls or []
|
||||
),
|
||||
{
|
||||
"type": "text",
|
||||
"text": (
|
||||
"ERROR PARSING YOUR RESPONSE:\n\n"
|
||||
f"{e.__class__.__name__}: {e}"
|
||||
),
|
||||
},
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
if attempts > 1:
|
||||
self._logger.debug(
|
||||
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||
)
|
||||
|
||||
return ChatModelResponse(
|
||||
response=assistant_msg,
|
||||
parsed_result=parsed_result,
|
||||
model_info=ANTHROPIC_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=t_input,
|
||||
completion_tokens_used=t_output,
|
||||
)
|
||||
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
prompt_messages: list[ChatMessage],
|
||||
model: AnthropicModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> tuple[list[MessageParam], MessageCreateParams]:
|
||||
"""Prepare arguments for message completion API call.
|
||||
|
||||
Args:
|
||||
prompt_messages: List of ChatMessages.
|
||||
model: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
list[MessageParam]: Prompt messages for the Anthropic call
|
||||
dict[str, Any]: Any other kwargs for the Anthropic call
|
||||
"""
|
||||
kwargs["model"] = model
|
||||
|
||||
if functions:
|
||||
kwargs["tools"] = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
name: param.to_dict()
|
||||
for name, param in f.parameters.items()
|
||||
},
|
||||
"required": [
|
||||
name
|
||||
for name, param in f.parameters.items()
|
||||
if param.required
|
||||
],
|
||||
},
|
||||
}
|
||||
for f in functions
|
||||
]
|
||||
|
||||
kwargs["max_tokens"] = max_output_tokens or 4096
|
||||
|
||||
if extra_headers := self._configuration.extra_request_headers:
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
system_messages = [
|
||||
m for m in prompt_messages if m.role == ChatMessage.Role.SYSTEM
|
||||
]
|
||||
if (_n := len(system_messages)) > 1:
|
||||
self._logger.warning(
|
||||
f"Prompt has {_n} system messages; Anthropic supports only 1. "
|
||||
"They will be merged, and removed from the rest of the prompt."
|
||||
)
|
||||
kwargs["system"] = "\n\n".join(sm.content for sm in system_messages)
|
||||
|
||||
messages: list[MessageParam] = []
|
||||
for message in prompt_messages:
|
||||
if message.role == ChatMessage.Role.SYSTEM:
|
||||
continue
|
||||
elif message.role == ChatMessage.Role.USER:
|
||||
# Merge subsequent user messages
|
||||
if messages and (prev_msg := messages[-1])["role"] == "user":
|
||||
if isinstance(prev_msg["content"], str):
|
||||
prev_msg["content"] += f"\n\n{message.content}"
|
||||
else:
|
||||
assert isinstance(prev_msg["content"], list)
|
||||
prev_msg["content"].append(
|
||||
{"type": "text", "text": message.content}
|
||||
)
|
||||
else:
|
||||
messages.append({"role": "user", "content": message.content})
|
||||
# TODO: add support for image blocks
|
||||
elif message.role == ChatMessage.Role.ASSISTANT:
|
||||
if isinstance(message, AssistantChatMessage) and message.tool_calls:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
*(
|
||||
[{"type": "text", "text": message.content}]
|
||||
if message.content
|
||||
else []
|
||||
),
|
||||
*(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tc.id,
|
||||
"name": tc.function.name,
|
||||
"input": tc.function.arguments,
|
||||
}
|
||||
for tc in message.tool_calls
|
||||
),
|
||||
],
|
||||
}
|
||||
)
|
||||
elif message.content:
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": message.content,
|
||||
}
|
||||
)
|
||||
elif isinstance(message, ToolResultMessage):
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message.tool_call_id,
|
||||
"content": [{"type": "text", "text": message.content}],
|
||||
"is_error": message.is_error,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
return messages, kwargs # type: ignore
|
||||
|
||||
async def _create_chat_completion(
|
||||
self, completion_kwargs: MessageCreateParams
|
||||
) -> tuple[Message, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the Anthropic API with retry handling.
|
||||
|
||||
Params:
|
||||
completion_kwargs: Keyword arguments for an Anthropic Messages API call
|
||||
|
||||
Returns:
|
||||
Message: The message completion object
|
||||
float: The cost ($) of this completion
|
||||
int: Number of input tokens used
|
||||
int: Number of output tokens used
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
completion_kwargs: MessageCreateParams,
|
||||
) -> Message:
|
||||
return await self._client.beta.tools.messages.create(
|
||||
**completion_kwargs # type: ignore
|
||||
)
|
||||
|
||||
response = await _create_chat_completion_with_retry(completion_kwargs)
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=ANTHROPIC_CHAT_MODELS[completion_kwargs["model"]],
|
||||
input_tokens_used=response.usage.input_tokens,
|
||||
output_tokens_used=response.usage.output_tokens,
|
||||
)
|
||||
return response, cost, response.usage.input_tokens, response.usage.output_tokens
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: Message
|
||||
) -> list[AssistantToolCall]:
|
||||
return [
|
||||
AssistantToolCall(
|
||||
id=c.id,
|
||||
type="function",
|
||||
function=AssistantFunctionCall(name=c.name, arguments=c.input),
|
||||
)
|
||||
for c in assistant_message.content
|
||||
if c.type == "tool_use"
|
||||
]
|
||||
|
||||
def _retry_api_request(self, func: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
return tenacity.retry(
|
||||
retry=(
|
||||
tenacity.retry_if_exception_type(APIConnectionError)
|
||||
| tenacity.retry_if_exception(
|
||||
lambda e: isinstance(e, APIStatusError) and e.status_code >= 500
|
||||
)
|
||||
),
|
||||
wait=tenacity.wait_exponential(),
|
||||
stop=tenacity.stop_after_attempt(self._configuration.retries_per_request),
|
||||
after=tenacity.after_log(self._logger, logging.DEBUG),
|
||||
)(func)
|
||||
|
||||
def __repr__(self):
|
||||
return "AnthropicProvider()"
|
||||
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
162
autogpts/autogpt/autogpt/core/resource/model_providers/multi.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterator, Optional, TypeVar
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from autogpt.core.configuration import Configurable
|
||||
|
||||
from .anthropic import ANTHROPIC_CHAT_MODELS, AnthropicModelName, AnthropicProvider
|
||||
from .openai import OPEN_AI_CHAT_MODELS, OpenAIModelName, OpenAIProvider
|
||||
from .schema import (
|
||||
AssistantChatMessage,
|
||||
ChatMessage,
|
||||
ChatModelInfo,
|
||||
ChatModelProvider,
|
||||
ChatModelResponse,
|
||||
CompletionModelFunction,
|
||||
ModelProviderBudget,
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderName,
|
||||
ModelProviderSettings,
|
||||
ModelTokenizer,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
ModelName = AnthropicModelName | OpenAIModelName
|
||||
|
||||
CHAT_MODELS = {**ANTHROPIC_CHAT_MODELS, **OPEN_AI_CHAT_MODELS}
|
||||
|
||||
|
||||
class MultiProvider(Configurable[ModelProviderSettings], ChatModelProvider):
|
||||
default_settings = ModelProviderSettings(
|
||||
name="multi_provider",
|
||||
description=(
|
||||
"Provides access to all of the available models, regardless of provider."
|
||||
),
|
||||
configuration=ModelProviderConfiguration(
|
||||
retries_per_request=7,
|
||||
),
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
_provider_instances: dict[ModelProviderName, ChatModelProvider]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
super(MultiProvider, self).__init__(settings=settings, logger=logger)
|
||||
self._budget = self._settings.budget or ModelProviderBudget()
|
||||
|
||||
self._provider_instances = {}
|
||||
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
models = []
|
||||
for provider in self.get_available_providers():
|
||||
models.extend(await provider.get_available_models())
|
||||
return models
|
||||
|
||||
def get_token_limit(self, model_name: ModelName) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
return self.get_model_provider(model_name).get_token_limit(model_name)
|
||||
|
||||
@classmethod
|
||||
def get_tokenizer(cls, model_name: ModelName) -> ModelTokenizer:
|
||||
return cls._get_model_provider_class(model_name).get_tokenizer(model_name)
|
||||
|
||||
@classmethod
|
||||
def count_tokens(cls, text: str, model_name: ModelName) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_tokens(
|
||||
text=text, model_name=model_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def count_message_tokens(
|
||||
cls, messages: ChatMessage | list[ChatMessage], model_name: ModelName
|
||||
) -> int:
|
||||
return cls._get_model_provider_class(model_name).count_message_tokens(
|
||||
messages=messages, model_name=model_name
|
||||
)
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: ModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the Anthropic API."""
|
||||
return await self.get_model_provider(model_name).create_chat_completion(
|
||||
model_prompt=model_prompt,
|
||||
model_name=model_name,
|
||||
completion_parser=completion_parser,
|
||||
functions=functions,
|
||||
max_output_tokens=max_output_tokens,
|
||||
prefill_response=prefill_response,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_model_provider(self, model: ModelName) -> ChatModelProvider:
|
||||
model_info = CHAT_MODELS[model]
|
||||
return self._get_provider(model_info.provider_name)
|
||||
|
||||
def get_available_providers(self) -> Iterator[ChatModelProvider]:
|
||||
for provider_name in ModelProviderName:
|
||||
try:
|
||||
yield self._get_provider(provider_name)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _get_provider(self, provider_name: ModelProviderName) -> ChatModelProvider:
|
||||
_provider = self._provider_instances.get(provider_name)
|
||||
if not _provider:
|
||||
Provider = self._get_provider_class(provider_name)
|
||||
settings = Provider.default_settings.copy(deep=True)
|
||||
settings.budget = self._budget
|
||||
settings.configuration.extra_request_headers.update(
|
||||
self._settings.configuration.extra_request_headers
|
||||
)
|
||||
if settings.credentials is None:
|
||||
try:
|
||||
Credentials = settings.__fields__["credentials"].type_
|
||||
settings.credentials = Credentials.from_env()
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f"{provider_name} is unavailable: can't load credentials"
|
||||
) from e
|
||||
|
||||
self._provider_instances[provider_name] = _provider = Provider(
|
||||
settings=settings, logger=self._logger
|
||||
)
|
||||
_provider._budget = self._budget # Object binding not preserved by Pydantic
|
||||
return _provider
|
||||
|
||||
@classmethod
|
||||
def _get_model_provider_class(
|
||||
cls, model_name: ModelName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
return cls._get_provider_class(CHAT_MODELS[model_name].provider_name)
|
||||
|
||||
@classmethod
|
||||
def _get_provider_class(
|
||||
cls, provider_name: ModelProviderName
|
||||
) -> type[AnthropicProvider | OpenAIProvider]:
|
||||
try:
|
||||
return {
|
||||
ModelProviderName.ANTHROPIC: AnthropicProvider,
|
||||
ModelProviderName.OPENAI: OpenAIProvider,
|
||||
}[provider_name]
|
||||
except KeyError:
|
||||
raise ValueError(f"{provider_name} is not a known provider") from None
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}()"
|
||||
@@ -1,21 +1,26 @@
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||
from typing import Any, Callable, Coroutine, Iterator, Optional, ParamSpec, TypeVar
|
||||
|
||||
import sentry_sdk
|
||||
import tenacity
|
||||
import tiktoken
|
||||
import yaml
|
||||
from openai._exceptions import APIStatusError, RateLimitError
|
||||
from openai.types import CreateEmbeddingResponse
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageParam,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration import Configurable, UserConfigurable
|
||||
from autogpt.core.resource.model_providers.schema import (
|
||||
AssistantChatMessage,
|
||||
AssistantFunctionCall,
|
||||
AssistantToolCall,
|
||||
AssistantToolCallDict,
|
||||
ChatMessage,
|
||||
@@ -31,27 +36,30 @@ from autogpt.core.resource.model_providers.schema import (
|
||||
ModelProviderConfiguration,
|
||||
ModelProviderCredentials,
|
||||
ModelProviderName,
|
||||
ModelProviderService,
|
||||
ModelProviderSettings,
|
||||
ModelProviderUsage,
|
||||
ModelTokenizer,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.core.utils.json_utils import json_loads
|
||||
|
||||
from .utils import validate_tool_calls
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
OpenAIEmbeddingParser = Callable[[Embedding], Embedding]
|
||||
OpenAIChatParser = Callable[[str], dict]
|
||||
|
||||
|
||||
class OpenAIModelName(str, enum.Enum):
|
||||
ADA = "text-embedding-ada-002"
|
||||
EMBEDDING_v2 = "text-embedding-ada-002"
|
||||
EMBEDDING_v3_S = "text-embedding-3-small"
|
||||
EMBEDDING_v3_L = "text-embedding-3-large"
|
||||
|
||||
GPT3_v1 = "gpt-3.5-turbo-0301"
|
||||
GPT3_v2 = "gpt-3.5-turbo-0613"
|
||||
GPT3_v2_16k = "gpt-3.5-turbo-16k-0613"
|
||||
GPT3_v3 = "gpt-3.5-turbo-1106"
|
||||
GPT3_v4 = "gpt-3.5-turbo-0125"
|
||||
GPT3_ROLLING = "gpt-3.5-turbo"
|
||||
GPT3_ROLLING_16k = "gpt-3.5-turbo-16k"
|
||||
GPT3 = GPT3_ROLLING
|
||||
@@ -62,24 +70,43 @@ class OpenAIModelName(str, enum.Enum):
|
||||
GPT4_v2 = "gpt-4-0613"
|
||||
GPT4_v2_32k = "gpt-4-32k-0613"
|
||||
GPT4_v3 = "gpt-4-1106-preview"
|
||||
GPT4_v3_VISION = "gpt-4-1106-vision-preview"
|
||||
GPT4_v4 = "gpt-4-0125-preview"
|
||||
GPT4_v5 = "gpt-4-turbo-2024-04-09"
|
||||
GPT4_ROLLING = "gpt-4"
|
||||
GPT4_ROLLING_32k = "gpt-4-32k"
|
||||
GPT4_TURBO = "gpt-4-turbo-preview"
|
||||
GPT4_TURBO = "gpt-4-turbo"
|
||||
GPT4_TURBO_PREVIEW = "gpt-4-turbo-preview"
|
||||
GPT4_VISION = "gpt-4-vision-preview"
|
||||
GPT4 = GPT4_ROLLING
|
||||
GPT4_32k = GPT4_ROLLING_32k
|
||||
|
||||
|
||||
OPEN_AI_EMBEDDING_MODELS = {
|
||||
OpenAIModelName.ADA: EmbeddingModelInfo(
|
||||
name=OpenAIModelName.ADA,
|
||||
service=ModelProviderService.EMBEDDING,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.0001 / 1000,
|
||||
max_tokens=8191,
|
||||
embedding_dimensions=1536,
|
||||
),
|
||||
info.name: info
|
||||
for info in [
|
||||
EmbeddingModelInfo(
|
||||
name=OpenAIModelName.EMBEDDING_v2,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.0001 / 1000,
|
||||
max_tokens=8191,
|
||||
embedding_dimensions=1536,
|
||||
),
|
||||
EmbeddingModelInfo(
|
||||
name=OpenAIModelName.EMBEDDING_v3_S,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.00002 / 1000,
|
||||
max_tokens=8191,
|
||||
embedding_dimensions=1536,
|
||||
),
|
||||
EmbeddingModelInfo(
|
||||
name=OpenAIModelName.EMBEDDING_v3_L,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.00013 / 1000,
|
||||
max_tokens=8191,
|
||||
embedding_dimensions=3072,
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -87,8 +114,7 @@ OPEN_AI_CHAT_MODELS = {
|
||||
info.name: info
|
||||
for info in [
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT3,
|
||||
service=ModelProviderService.CHAT,
|
||||
name=OpenAIModelName.GPT3_v1,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.0015 / 1000,
|
||||
completion_token_cost=0.002 / 1000,
|
||||
@@ -96,8 +122,7 @@ OPEN_AI_CHAT_MODELS = {
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT3_16k,
|
||||
service=ModelProviderService.CHAT,
|
||||
name=OpenAIModelName.GPT3_v2_16k,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.003 / 1000,
|
||||
completion_token_cost=0.004 / 1000,
|
||||
@@ -106,7 +131,6 @@ OPEN_AI_CHAT_MODELS = {
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT3_v3,
|
||||
service=ModelProviderService.CHAT,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.001 / 1000,
|
||||
completion_token_cost=0.002 / 1000,
|
||||
@@ -114,8 +138,15 @@ OPEN_AI_CHAT_MODELS = {
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4,
|
||||
service=ModelProviderService.CHAT,
|
||||
name=OpenAIModelName.GPT3_v4,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.0005 / 1000,
|
||||
completion_token_cost=0.0015 / 1000,
|
||||
max_tokens=16384,
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4_v1,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.03 / 1000,
|
||||
completion_token_cost=0.06 / 1000,
|
||||
@@ -123,8 +154,7 @@ OPEN_AI_CHAT_MODELS = {
|
||||
has_function_call_api=True,
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4_32k,
|
||||
service=ModelProviderService.CHAT,
|
||||
name=OpenAIModelName.GPT4_v1_32k,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.06 / 1000,
|
||||
completion_token_cost=0.12 / 1000,
|
||||
@@ -133,7 +163,6 @@ OPEN_AI_CHAT_MODELS = {
|
||||
),
|
||||
ChatModelInfo(
|
||||
name=OpenAIModelName.GPT4_TURBO,
|
||||
service=ModelProviderService.CHAT,
|
||||
provider_name=ModelProviderName.OPENAI,
|
||||
prompt_token_cost=0.01 / 1000,
|
||||
completion_token_cost=0.03 / 1000,
|
||||
@@ -144,19 +173,26 @@ OPEN_AI_CHAT_MODELS = {
|
||||
}
|
||||
# Copy entries for models with equivalent specs
|
||||
chat_model_mapping = {
|
||||
OpenAIModelName.GPT3: [OpenAIModelName.GPT3_v1, OpenAIModelName.GPT3_v2],
|
||||
OpenAIModelName.GPT3_16k: [OpenAIModelName.GPT3_v2_16k],
|
||||
OpenAIModelName.GPT4: [OpenAIModelName.GPT4_v1, OpenAIModelName.GPT4_v2],
|
||||
OpenAIModelName.GPT4_32k: [
|
||||
OpenAIModelName.GPT4_v1_32k,
|
||||
OpenAIModelName.GPT3_v1: [OpenAIModelName.GPT3_v2],
|
||||
OpenAIModelName.GPT3_v2_16k: [OpenAIModelName.GPT3_16k],
|
||||
OpenAIModelName.GPT3_v4: [OpenAIModelName.GPT3_ROLLING],
|
||||
OpenAIModelName.GPT4_v1: [OpenAIModelName.GPT4_v2, OpenAIModelName.GPT4_ROLLING],
|
||||
OpenAIModelName.GPT4_v1_32k: [
|
||||
OpenAIModelName.GPT4_v2_32k,
|
||||
OpenAIModelName.GPT4_32k,
|
||||
],
|
||||
OpenAIModelName.GPT4_TURBO: [
|
||||
OpenAIModelName.GPT4_v3,
|
||||
OpenAIModelName.GPT4_v3_VISION,
|
||||
OpenAIModelName.GPT4_VISION,
|
||||
OpenAIModelName.GPT4_v4,
|
||||
OpenAIModelName.GPT4_TURBO_PREVIEW,
|
||||
OpenAIModelName.GPT4_v5,
|
||||
],
|
||||
OpenAIModelName.GPT4_TURBO: [OpenAIModelName.GPT4_v3, OpenAIModelName.GPT4_v4],
|
||||
}
|
||||
for base, copies in chat_model_mapping.items():
|
||||
for copy in copies:
|
||||
copy_info = ChatModelInfo(**OPEN_AI_CHAT_MODELS[base].__dict__)
|
||||
copy_info.name = copy
|
||||
copy_info = OPEN_AI_CHAT_MODELS[base].copy(update={"name": copy})
|
||||
OPEN_AI_CHAT_MODELS[copy] = copy_info
|
||||
if copy.endswith(("-0301", "-0314")):
|
||||
copy_info.has_function_call_api = False
|
||||
@@ -205,7 +241,8 @@ class OpenAICredentials(ModelProviderCredentials):
|
||||
}
|
||||
if self.api_type == "azure":
|
||||
kwargs["api_version"] = self.api_version
|
||||
kwargs["azure_endpoint"] = self.azure_endpoint
|
||||
assert self.azure_endpoint, "Azure endpoint not configured"
|
||||
kwargs["azure_endpoint"] = self.azure_endpoint.get_secret_value()
|
||||
return kwargs
|
||||
|
||||
def get_model_access_kwargs(self, model: str) -> dict[str, str]:
|
||||
@@ -217,7 +254,7 @@ class OpenAICredentials(ModelProviderCredentials):
|
||||
|
||||
def load_azure_config(self, config_file: Path) -> None:
|
||||
with open(config_file) as file:
|
||||
config_params = yaml.load(file, Loader=yaml.FullLoader) or {}
|
||||
config_params = yaml.load(file, Loader=yaml.SafeLoader) or {}
|
||||
|
||||
try:
|
||||
assert config_params.get(
|
||||
@@ -257,35 +294,28 @@ class OpenAIProvider(
|
||||
name="openai_provider",
|
||||
description="Provides access to OpenAI's API.",
|
||||
configuration=OpenAIConfiguration(
|
||||
retries_per_request=10,
|
||||
retries_per_request=7,
|
||||
),
|
||||
credentials=None,
|
||||
budget=ModelProviderBudget(
|
||||
total_budget=math.inf,
|
||||
total_cost=0.0,
|
||||
remaining_budget=math.inf,
|
||||
usage=ModelProviderUsage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
),
|
||||
),
|
||||
budget=ModelProviderBudget(),
|
||||
)
|
||||
|
||||
_budget: ModelProviderBudget
|
||||
_settings: OpenAISettings
|
||||
_configuration: OpenAIConfiguration
|
||||
_credentials: OpenAICredentials
|
||||
_budget: ModelProviderBudget
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: OpenAISettings,
|
||||
logger: logging.Logger,
|
||||
settings: Optional[OpenAISettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
self._settings = settings
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
if not settings.credentials:
|
||||
settings.credentials = OpenAICredentials.from_env()
|
||||
|
||||
assert settings.credentials, "Cannot create OpenAIProvider without credentials"
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
super(OpenAIProvider, self).__init__(settings=settings, logger=logger)
|
||||
|
||||
if self._credentials.api_type == "azure":
|
||||
from openai import AsyncAzureOpenAI
|
||||
@@ -298,7 +328,9 @@ class OpenAIProvider(
|
||||
|
||||
self._client = AsyncOpenAI(**self._credentials.get_api_access_kwargs())
|
||||
|
||||
self._logger = logger
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
_models = (await self._client.models.list()).data
|
||||
return [OPEN_AI_MODELS[m.id] for m in _models if m.id in OPEN_AI_MODELS]
|
||||
|
||||
def get_token_limit(self, model_name: str) -> int:
|
||||
"""Get the token limit for a given model."""
|
||||
@@ -362,79 +394,104 @@ class OpenAIProvider(
|
||||
model_name: OpenAIModelName,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "", # not supported by OpenAI
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
"""Create a completion using the OpenAI API."""
|
||||
"""Create a completion using the OpenAI API and parse it."""
|
||||
|
||||
completion_kwargs = self._get_completion_kwargs(model_name, functions, **kwargs)
|
||||
tool_calls_compat_mode = functions and "tools" not in completion_kwargs
|
||||
if "messages" in completion_kwargs:
|
||||
model_prompt += completion_kwargs["messages"]
|
||||
del completion_kwargs["messages"]
|
||||
openai_messages, completion_kwargs = self._get_chat_completion_args(
|
||||
model_prompt=model_prompt,
|
||||
model_name=model_name,
|
||||
functions=functions,
|
||||
max_tokens=max_output_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
tool_calls_compat_mode = bool(functions and "tools" not in completion_kwargs)
|
||||
|
||||
cost = 0.0
|
||||
total_cost = 0.0
|
||||
attempts = 0
|
||||
while True:
|
||||
_response = await self._create_chat_completion(
|
||||
messages=model_prompt,
|
||||
_response, _cost, t_input, t_output = await self._create_chat_completion(
|
||||
messages=openai_messages,
|
||||
**completion_kwargs,
|
||||
)
|
||||
|
||||
_assistant_msg = _response.choices[0].message
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=(
|
||||
[AssistantToolCall(**tc.dict()) for tc in _assistant_msg.tool_calls]
|
||||
if _assistant_msg.tool_calls
|
||||
else None
|
||||
),
|
||||
)
|
||||
response = ChatModelResponse(
|
||||
response=assistant_msg,
|
||||
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=(
|
||||
_response.usage.prompt_tokens if _response.usage else 0
|
||||
),
|
||||
completion_tokens_used=(
|
||||
_response.usage.completion_tokens if _response.usage else 0
|
||||
),
|
||||
)
|
||||
cost += self._budget.update_usage_and_cost(response)
|
||||
self._logger.debug(
|
||||
f"Completion usage: {response.prompt_tokens_used} input, "
|
||||
f"{response.completion_tokens_used} output - ${round(cost, 5)}"
|
||||
)
|
||||
total_cost += _cost
|
||||
|
||||
# If parsing the response fails, append the error to the prompt, and let the
|
||||
# LLM fix its mistake(s).
|
||||
try:
|
||||
attempts += 1
|
||||
attempts += 1
|
||||
parse_errors: list[Exception] = []
|
||||
|
||||
if (
|
||||
tool_calls_compat_mode
|
||||
and assistant_msg.content
|
||||
and not assistant_msg.tool_calls
|
||||
):
|
||||
assistant_msg.tool_calls = list(
|
||||
_tool_calls_compat_extract_calls(assistant_msg.content)
|
||||
_assistant_msg = _response.choices[0].message
|
||||
|
||||
tool_calls, _errors = self._parse_assistant_tool_calls(
|
||||
_assistant_msg, tool_calls_compat_mode
|
||||
)
|
||||
parse_errors += _errors
|
||||
|
||||
# Validate tool calls
|
||||
if not parse_errors and tool_calls and functions:
|
||||
parse_errors += validate_tool_calls(tool_calls, functions)
|
||||
|
||||
assistant_msg = AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
parsed_result: _T = None # type: ignore
|
||||
if not parse_errors:
|
||||
try:
|
||||
parsed_result = completion_parser(assistant_msg)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
if not parse_errors:
|
||||
if attempts > 1:
|
||||
self._logger.debug(
|
||||
f"Total cost for {attempts} attempts: ${round(total_cost, 5)}"
|
||||
)
|
||||
|
||||
return ChatModelResponse(
|
||||
response=AssistantChatMessage(
|
||||
content=_assistant_msg.content,
|
||||
tool_calls=tool_calls or None,
|
||||
),
|
||||
parsed_result=parsed_result,
|
||||
model_info=OPEN_AI_CHAT_MODELS[model_name],
|
||||
prompt_tokens_used=t_input,
|
||||
completion_tokens_used=t_output,
|
||||
)
|
||||
|
||||
else:
|
||||
self._logger.debug(
|
||||
f"Parsing failed on response: '''{_assistant_msg}'''"
|
||||
)
|
||||
parse_errors_fmt = "\n\n".join(
|
||||
f"{e.__class__.__name__}: {e}" for e in parse_errors
|
||||
)
|
||||
self._logger.warning(
|
||||
f"Parsing attempt #{attempts} failed: {parse_errors_fmt}"
|
||||
)
|
||||
for e in parse_errors:
|
||||
sentry_sdk.capture_exception(
|
||||
error=e,
|
||||
extras={"assistant_msg": _assistant_msg, "i_attempt": attempts},
|
||||
)
|
||||
|
||||
response.parsed_result = completion_parser(assistant_msg)
|
||||
break
|
||||
except Exception as e:
|
||||
self._logger.warning(f"Parsing attempt #{attempts} failed: {e}")
|
||||
self._logger.debug(f"Parsing failed on response: '''{assistant_msg}'''")
|
||||
if attempts < self._configuration.fix_failed_parse_tries:
|
||||
model_prompt.append(
|
||||
ChatMessage.system(f"ERROR PARSING YOUR RESPONSE:\n\n{e}")
|
||||
openai_messages.append(_assistant_msg.dict(exclude_none=True))
|
||||
openai_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
f"ERROR PARSING YOUR RESPONSE:\n\n{parse_errors_fmt}"
|
||||
),
|
||||
}
|
||||
)
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
if attempts > 1:
|
||||
self._logger.debug(f"Total cost for {attempts} attempts: ${round(cost, 5)}")
|
||||
|
||||
return response
|
||||
raise parse_errors[0]
|
||||
|
||||
async def create_embedding(
|
||||
self,
|
||||
@@ -456,21 +513,24 @@ class OpenAIProvider(
|
||||
self._budget.update_usage_and_cost(response)
|
||||
return response
|
||||
|
||||
def _get_completion_kwargs(
|
||||
def _get_chat_completion_args(
|
||||
self,
|
||||
model_prompt: list[ChatMessage],
|
||||
model_name: OpenAIModelName,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Get kwargs for completion API call.
|
||||
) -> tuple[list[ChatCompletionMessageParam], dict[str, Any]]:
|
||||
"""Prepare chat completion arguments and keyword arguments for API call.
|
||||
|
||||
Args:
|
||||
model: The model to use.
|
||||
kwargs: Keyword arguments to override the default values.
|
||||
model_prompt: List of ChatMessages.
|
||||
model_name: The model to use.
|
||||
functions: Optional list of functions available to the LLM.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The kwargs for the chat API call.
|
||||
|
||||
list[ChatCompletionMessageParam]: Prompt messages for the OpenAI call
|
||||
dict[str, Any]: Any other kwargs for the OpenAI call
|
||||
"""
|
||||
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
|
||||
|
||||
@@ -490,11 +550,22 @@ class OpenAIProvider(
|
||||
_functions_compat_fix_kwargs(functions, kwargs)
|
||||
|
||||
if extra_headers := self._configuration.extra_request_headers:
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
|
||||
extra_headers.copy()
|
||||
)
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
return kwargs
|
||||
if "messages" in kwargs:
|
||||
model_prompt += kwargs["messages"]
|
||||
del kwargs["messages"]
|
||||
|
||||
openai_messages: list[ChatCompletionMessageParam] = [
|
||||
message.dict(
|
||||
include={"role", "content", "tool_calls", "name"},
|
||||
exclude_none=True,
|
||||
)
|
||||
for message in model_prompt
|
||||
]
|
||||
|
||||
return openai_messages, kwargs
|
||||
|
||||
def _get_embedding_kwargs(
|
||||
self,
|
||||
@@ -514,31 +585,108 @@ class OpenAIProvider(
|
||||
kwargs.update(self._credentials.get_model_access_kwargs(model_name))
|
||||
|
||||
if extra_headers := self._configuration.extra_request_headers:
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {}).update(
|
||||
extra_headers.copy()
|
||||
)
|
||||
kwargs["extra_headers"] = kwargs.get("extra_headers", {})
|
||||
kwargs["extra_headers"].update(extra_headers.copy())
|
||||
|
||||
return kwargs
|
||||
|
||||
def _create_chat_completion(
|
||||
self, messages: list[ChatMessage], *_, **kwargs
|
||||
) -> Coroutine[None, None, ChatCompletion]:
|
||||
"""Create a chat completion using the OpenAI API with retry handling."""
|
||||
async def _create_chat_completion(
|
||||
self,
|
||||
messages: list[ChatCompletionMessageParam],
|
||||
model: OpenAIModelName,
|
||||
*_,
|
||||
**kwargs,
|
||||
) -> tuple[ChatCompletion, float, int, int]:
|
||||
"""
|
||||
Create a chat completion using the OpenAI API with retry handling.
|
||||
|
||||
Params:
|
||||
openai_messages: List of OpenAI-consumable message dict objects
|
||||
model: The model to use for the completion
|
||||
|
||||
Returns:
|
||||
ChatCompletion: The chat completion response object
|
||||
float: The cost ($) of this completion
|
||||
int: Number of prompt tokens used
|
||||
int: Number of completion tokens used
|
||||
"""
|
||||
|
||||
@self._retry_api_request
|
||||
async def _create_chat_completion_with_retry(
|
||||
messages: list[ChatMessage], *_, **kwargs
|
||||
messages: list[ChatCompletionMessageParam], **kwargs
|
||||
) -> ChatCompletion:
|
||||
raw_messages = [
|
||||
message.dict(include={"role", "content", "tool_calls", "name"})
|
||||
for message in messages
|
||||
]
|
||||
return await self._client.chat.completions.create(
|
||||
messages=raw_messages, # type: ignore
|
||||
messages=messages, # type: ignore
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _create_chat_completion_with_retry(messages, *_, **kwargs)
|
||||
completion = await _create_chat_completion_with_retry(
|
||||
messages, model=model, **kwargs
|
||||
)
|
||||
|
||||
if completion.usage:
|
||||
prompt_tokens_used = completion.usage.prompt_tokens
|
||||
completion_tokens_used = completion.usage.completion_tokens
|
||||
else:
|
||||
prompt_tokens_used = completion_tokens_used = 0
|
||||
|
||||
cost = self._budget.update_usage_and_cost(
|
||||
model_info=OPEN_AI_CHAT_MODELS[model],
|
||||
input_tokens_used=prompt_tokens_used,
|
||||
output_tokens_used=completion_tokens_used,
|
||||
)
|
||||
self._logger.debug(
|
||||
f"Completion usage: {prompt_tokens_used} input, "
|
||||
f"{completion_tokens_used} output - ${round(cost, 5)}"
|
||||
)
|
||||
return completion, cost, prompt_tokens_used, completion_tokens_used
|
||||
|
||||
def _parse_assistant_tool_calls(
|
||||
self, assistant_message: ChatCompletionMessage, compat_mode: bool = False
|
||||
):
|
||||
tool_calls: list[AssistantToolCall] = []
|
||||
parse_errors: list[Exception] = []
|
||||
|
||||
if assistant_message.tool_calls:
|
||||
for _tc in assistant_message.tool_calls:
|
||||
try:
|
||||
parsed_arguments = json_loads(_tc.function.arguments)
|
||||
except Exception as e:
|
||||
err_message = (
|
||||
f"Decoding arguments for {_tc.function.name} failed: "
|
||||
+ str(e.args[0])
|
||||
)
|
||||
parse_errors.append(
|
||||
type(e)(err_message, *e.args[1:]).with_traceback(
|
||||
e.__traceback__
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
tool_calls.append(
|
||||
AssistantToolCall(
|
||||
id=_tc.id,
|
||||
type=_tc.type,
|
||||
function=AssistantFunctionCall(
|
||||
name=_tc.function.name,
|
||||
arguments=parsed_arguments,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# If parsing of all tool calls succeeds in the end, we ignore any issues
|
||||
if len(tool_calls) == len(assistant_message.tool_calls):
|
||||
parse_errors = []
|
||||
|
||||
elif compat_mode and assistant_message.content:
|
||||
try:
|
||||
tool_calls = list(
|
||||
_tool_calls_compat_extract_calls(assistant_message.content)
|
||||
)
|
||||
except Exception as e:
|
||||
parse_errors.append(e)
|
||||
|
||||
return tool_calls, parse_errors
|
||||
|
||||
def _create_embedding(
|
||||
self, text: str, *_, **kwargs
|
||||
@@ -713,20 +861,21 @@ def _functions_compat_fix_kwargs(
|
||||
|
||||
|
||||
def _tool_calls_compat_extract_calls(response: str) -> Iterator[AssistantToolCall]:
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
|
||||
logging.debug(f"Trying to extract tool calls from response:\n{response}")
|
||||
|
||||
if response[0] == "[":
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(response)
|
||||
tool_calls: list[AssistantToolCallDict] = json_loads(response)
|
||||
else:
|
||||
block = re.search(r"```(?:tool_calls)?\n(.*)\n```\s*$", response, re.DOTALL)
|
||||
if not block:
|
||||
raise ValueError("Could not find tool_calls block in response")
|
||||
tool_calls: list[AssistantToolCallDict] = json.loads(block.group(1))
|
||||
tool_calls: list[AssistantToolCallDict] = json_loads(block.group(1))
|
||||
|
||||
for t in tool_calls:
|
||||
t["id"] = str(uuid.uuid4())
|
||||
t["function"]["arguments"] = str(t["function"]["arguments"]) # HACK
|
||||
|
||||
yield AssistantToolCall.parse_obj(t)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import abc
|
||||
import enum
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Generic,
|
||||
@@ -24,6 +28,10 @@ from autogpt.core.resource.schema import (
|
||||
ResourceType,
|
||||
)
|
||||
from autogpt.core.utils.json_schema import JSONSchema
|
||||
from autogpt.logs.utils import fmt_kwargs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jsonschema import ValidationError
|
||||
|
||||
|
||||
class ModelProviderService(str, enum.Enum):
|
||||
@@ -36,6 +44,7 @@ class ModelProviderService(str, enum.Enum):
|
||||
|
||||
class ModelProviderName(str, enum.Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
@@ -44,16 +53,14 @@ class ChatMessage(BaseModel):
|
||||
SYSTEM = "system"
|
||||
ASSISTANT = "assistant"
|
||||
|
||||
TOOL = "tool"
|
||||
"""May be used for the result of tool calls"""
|
||||
FUNCTION = "function"
|
||||
"""May be used for the return value of function calls"""
|
||||
|
||||
role: Role
|
||||
content: str
|
||||
|
||||
@staticmethod
|
||||
def assistant(content: str) -> "ChatMessage":
|
||||
return ChatMessage(role=ChatMessage.Role.ASSISTANT, content=content)
|
||||
|
||||
@staticmethod
|
||||
def user(content: str) -> "ChatMessage":
|
||||
return ChatMessage(role=ChatMessage.Role.USER, content=content)
|
||||
@@ -70,30 +77,39 @@ class ChatMessageDict(TypedDict):
|
||||
|
||||
class AssistantFunctionCall(BaseModel):
|
||||
name: str
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.name}({fmt_kwargs(self.arguments)})"
|
||||
|
||||
|
||||
class AssistantFunctionCallDict(TypedDict):
|
||||
name: str
|
||||
arguments: str
|
||||
arguments: dict[str, Any]
|
||||
|
||||
|
||||
class AssistantToolCall(BaseModel):
|
||||
# id: str
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: AssistantFunctionCall
|
||||
|
||||
|
||||
class AssistantToolCallDict(TypedDict):
|
||||
# id: str
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: AssistantFunctionCallDict
|
||||
|
||||
|
||||
class AssistantChatMessage(ChatMessage):
|
||||
role: Literal["assistant"] = "assistant"
|
||||
role: Literal[ChatMessage.Role.ASSISTANT] = ChatMessage.Role.ASSISTANT
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[list[AssistantToolCall]]
|
||||
tool_calls: Optional[list[AssistantToolCall]] = None
|
||||
|
||||
|
||||
class ToolResultMessage(ChatMessage):
|
||||
role: Literal[ChatMessage.Role.TOOL] = ChatMessage.Role.TOOL
|
||||
is_error: bool = False
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AssistantChatMessageDict(TypedDict, total=False):
|
||||
@@ -137,10 +153,35 @@ class CompletionModelFunction(BaseModel):
|
||||
|
||||
def fmt_line(self) -> str:
|
||||
params = ", ".join(
|
||||
f"{name}: {p.type.value}" for name, p in self.parameters.items()
|
||||
f"{name}{'?' if not p.required else ''}: " f"{p.typescript_type}"
|
||||
for name, p in self.parameters.items()
|
||||
)
|
||||
return f"{self.name}: {self.description}. Params: ({params})"
|
||||
|
||||
def validate_call(
|
||||
self, function_call: AssistantFunctionCall
|
||||
) -> tuple[bool, list["ValidationError"]]:
|
||||
"""
|
||||
Validates the given function call against the function's parameter specs
|
||||
|
||||
Returns:
|
||||
bool: Whether the given set of arguments is valid for this command
|
||||
list[ValidationError]: Issues with the set of arguments (if any)
|
||||
|
||||
Raises:
|
||||
ValueError: If the function_call doesn't call this function
|
||||
"""
|
||||
if function_call.name != self.name:
|
||||
raise ValueError(
|
||||
f"Can't validate {function_call.name} call using {self.name} spec"
|
||||
)
|
||||
|
||||
params_schema = JSONSchema(
|
||||
type=JSONSchema.Type.OBJECT,
|
||||
properties={name: spec for name, spec in self.parameters.items()},
|
||||
)
|
||||
return params_schema.validate_object(function_call.arguments)
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
"""Struct for model information.
|
||||
@@ -187,39 +228,34 @@ class ModelProviderUsage(ProviderUsage):
|
||||
|
||||
completion_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
def update_usage(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
input_tokens_used: int,
|
||||
output_tokens_used: int = 0,
|
||||
) -> None:
|
||||
self.completion_tokens += model_response.completion_tokens_used
|
||||
self.prompt_tokens += model_response.prompt_tokens_used
|
||||
self.total_tokens += (
|
||||
model_response.completion_tokens_used + model_response.prompt_tokens_used
|
||||
)
|
||||
self.prompt_tokens += input_tokens_used
|
||||
self.completion_tokens += output_tokens_used
|
||||
|
||||
|
||||
class ModelProviderBudget(ProviderBudget):
|
||||
total_budget: float = UserConfigurable()
|
||||
total_cost: float
|
||||
remaining_budget: float
|
||||
usage: ModelProviderUsage
|
||||
usage: defaultdict[str, ModelProviderUsage] = defaultdict(ModelProviderUsage)
|
||||
|
||||
def update_usage_and_cost(
|
||||
self,
|
||||
model_response: ModelResponse,
|
||||
model_info: ModelInfo,
|
||||
input_tokens_used: int,
|
||||
output_tokens_used: int = 0,
|
||||
) -> float:
|
||||
"""Update the usage and cost of the provider.
|
||||
|
||||
Returns:
|
||||
float: The (calculated) cost of the given model response.
|
||||
"""
|
||||
model_info = model_response.model_info
|
||||
self.usage.update_usage(model_response)
|
||||
self.usage[model_info.name].update_usage(input_tokens_used, output_tokens_used)
|
||||
incurred_cost = (
|
||||
model_response.completion_tokens_used * model_info.completion_token_cost
|
||||
+ model_response.prompt_tokens_used * model_info.prompt_token_cost
|
||||
output_tokens_used * model_info.completion_token_cost
|
||||
+ input_tokens_used * model_info.prompt_token_cost
|
||||
)
|
||||
self.total_cost += incurred_cost
|
||||
self.remaining_budget -= incurred_cost
|
||||
@@ -229,8 +265,8 @@ class ModelProviderBudget(ProviderBudget):
|
||||
class ModelProviderSettings(ProviderSettings):
|
||||
resource_type: ResourceType = ResourceType.MODEL
|
||||
configuration: ModelProviderConfiguration
|
||||
credentials: ModelProviderCredentials
|
||||
budget: ModelProviderBudget
|
||||
credentials: Optional[ModelProviderCredentials] = None
|
||||
budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
|
||||
class ModelProvider(abc.ABC):
|
||||
@@ -238,8 +274,27 @@ class ModelProvider(abc.ABC):
|
||||
|
||||
default_settings: ClassVar[ModelProviderSettings]
|
||||
|
||||
_budget: Optional[ModelProviderBudget]
|
||||
_settings: ModelProviderSettings
|
||||
_configuration: ModelProviderConfiguration
|
||||
_credentials: Optional[ModelProviderCredentials] = None
|
||||
_budget: Optional[ModelProviderBudget] = None
|
||||
|
||||
_logger: logging.Logger
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Optional[ModelProviderSettings] = None,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
):
|
||||
if not settings:
|
||||
settings = self.default_settings.copy(deep=True)
|
||||
|
||||
self._settings = settings
|
||||
self._configuration = settings.configuration
|
||||
self._credentials = settings.credentials
|
||||
self._budget = settings.budget
|
||||
|
||||
self._logger = logger or logging.getLogger(self.__module__)
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_tokens(self, text: str, model_name: str) -> int:
|
||||
@@ -284,7 +339,7 @@ class ModelTokenizer(Protocol):
|
||||
class EmbeddingModelInfo(ModelInfo):
|
||||
"""Struct for embedding model information."""
|
||||
|
||||
llm_service = ModelProviderService.EMBEDDING
|
||||
service: Literal[ModelProviderService.EMBEDDING] = ModelProviderService.EMBEDDING
|
||||
max_tokens: int
|
||||
embedding_dimensions: int
|
||||
|
||||
@@ -322,7 +377,7 @@ class EmbeddingModelProvider(ModelProvider):
|
||||
class ChatModelInfo(ModelInfo):
|
||||
"""Struct for language model information."""
|
||||
|
||||
llm_service = ModelProviderService.CHAT
|
||||
service: Literal[ModelProviderService.CHAT] = ModelProviderService.CHAT
|
||||
max_tokens: int
|
||||
has_function_call_api: bool = False
|
||||
|
||||
@@ -338,6 +393,10 @@ class ChatModelResponse(ModelResponse, Generic[_T]):
|
||||
|
||||
|
||||
class ChatModelProvider(ModelProvider):
|
||||
@abc.abstractmethod
|
||||
async def get_available_models(self) -> list[ChatModelInfo]:
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def count_message_tokens(
|
||||
self,
|
||||
@@ -353,6 +412,8 @@ class ChatModelProvider(ModelProvider):
|
||||
model_name: str,
|
||||
completion_parser: Callable[[AssistantChatMessage], _T] = lambda _: None,
|
||||
functions: Optional[list[CompletionModelFunction]] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
prefill_response: str = "",
|
||||
**kwargs,
|
||||
) -> ChatModelResponse[_T]:
|
||||
...
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from typing import Any
|
||||
|
||||
from .schema import AssistantToolCall, CompletionModelFunction
|
||||
|
||||
|
||||
class InvalidFunctionCallError(Exception):
|
||||
def __init__(self, name: str, arguments: dict[str, Any], message: str):
|
||||
self.message = message
|
||||
self.name = name
|
||||
self.arguments = arguments
|
||||
super().__init__(message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Invalid function call for {self.name}: {self.message}"
|
||||
|
||||
|
||||
def validate_tool_calls(
|
||||
tool_calls: list[AssistantToolCall], functions: list[CompletionModelFunction]
|
||||
) -> list[InvalidFunctionCallError]:
|
||||
"""
|
||||
Validates a list of tool calls against a list of functions.
|
||||
|
||||
1. Tries to find a function matching each tool call
|
||||
2. If a matching function is found, validates the tool call's arguments,
|
||||
reporting any resulting errors
|
||||
2. If no matching function is found, an error "Unknown function X" is reported
|
||||
3. A list of all errors encountered during validation is returned
|
||||
|
||||
Params:
|
||||
tool_calls: A list of tool calls to validate.
|
||||
functions: A list of functions to validate against.
|
||||
|
||||
Returns:
|
||||
list[InvalidFunctionCallError]: All errors encountered during validation.
|
||||
"""
|
||||
errors: list[InvalidFunctionCallError] = []
|
||||
for tool_call in tool_calls:
|
||||
function_call = tool_call.function
|
||||
|
||||
if function := next(
|
||||
(f for f in functions if f.name == function_call.name),
|
||||
None,
|
||||
):
|
||||
is_valid, validation_errors = function.validate_call(function_call)
|
||||
if not is_valid:
|
||||
fmt_errors = [
|
||||
f"{'.'.join(str(p) for p in f.path)}: {f.message}"
|
||||
if f.path
|
||||
else f.message
|
||||
for f in validation_errors
|
||||
]
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=(
|
||||
"The set of arguments supplied is invalid:\n"
|
||||
+ "\n".join(fmt_errors)
|
||||
),
|
||||
)
|
||||
)
|
||||
else:
|
||||
errors.append(
|
||||
InvalidFunctionCallError(
|
||||
name=function_call.name,
|
||||
arguments=function_call.arguments,
|
||||
message=f"Unknown function {function_call.name}",
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
@@ -1,5 +1,6 @@
|
||||
import abc
|
||||
import enum
|
||||
import math
|
||||
|
||||
from pydantic import BaseModel, SecretBytes, SecretField, SecretStr
|
||||
|
||||
@@ -25,9 +26,9 @@ class ProviderUsage(SystemConfiguration, abc.ABC):
|
||||
|
||||
|
||||
class ProviderBudget(SystemConfiguration):
|
||||
total_budget: float = UserConfigurable()
|
||||
total_cost: float
|
||||
remaining_budget: float
|
||||
total_budget: float = UserConfigurable(math.inf)
|
||||
total_cost: float = 0
|
||||
remaining_budget: float = math.inf
|
||||
usage: ProviderUsage
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -4,10 +4,8 @@ from agent_protocol import StepHandler, StepResult
|
||||
|
||||
from autogpt.agents import Agent
|
||||
from autogpt.app.main import UserFeedback
|
||||
from autogpt.commands import COMMAND_CATEGORIES
|
||||
from autogpt.config import AIProfile, ConfigBuilder
|
||||
from autogpt.logs.helpers import user_friendly_output
|
||||
from autogpt.models.command_registry import CommandRegistry
|
||||
from autogpt.prompts.prompt import DEFAULT_TRIGGERING_PROMPT
|
||||
|
||||
|
||||
@@ -82,16 +80,16 @@ def bootstrap_agent(task, continuous_mode) -> Agent:
|
||||
config.logging.plain_console_output = True
|
||||
config.continuous_mode = continuous_mode
|
||||
config.temperature = 0
|
||||
command_registry = CommandRegistry.with_command_modules(COMMAND_CATEGORIES, config)
|
||||
config.memory_backend = "no_memory"
|
||||
ai_profile = AIProfile(
|
||||
ai_name="AutoGPT",
|
||||
ai_role="a multi-purpose AI assistant.",
|
||||
ai_goals=[task],
|
||||
)
|
||||
# FIXME this won't work - ai_profile and triggering_prompt is not a valid argument,
|
||||
# lacks file_storage, settings and llm_provider
|
||||
return Agent(
|
||||
command_registry=command_registry,
|
||||
ai_profile=ai_profile,
|
||||
config=config,
|
||||
legacy_config=config,
|
||||
triggering_prompt=DEFAULT_TRIGGERING_PROMPT,
|
||||
)
|
||||
|
||||
@@ -3,21 +3,25 @@ from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.core.prompting import ChatPrompt
|
||||
from autogpt.core.resource.model_providers import ChatMessage
|
||||
|
||||
SEPARATOR_LENGTH = 42
|
||||
|
||||
|
||||
def dump_prompt(prompt: "ChatPrompt") -> str:
|
||||
def dump_prompt(prompt: "ChatPrompt | list[ChatMessage]") -> str:
|
||||
def separator(text: str):
|
||||
half_sep_len = (SEPARATOR_LENGTH - 2 - len(text)) / 2
|
||||
return f"{floor(half_sep_len)*'-'} {text.upper()} {ceil(half_sep_len)*'-'}"
|
||||
|
||||
if not isinstance(prompt, list):
|
||||
prompt = prompt.messages
|
||||
|
||||
formatted_messages = "\n".join(
|
||||
[f"{separator(m.role)}\n{m.content}" for m in prompt.messages]
|
||||
[f"{separator(m.role)}\n{m.content}" for m in prompt]
|
||||
)
|
||||
return f"""
|
||||
============== {prompt.__class__.__name__} ==============
|
||||
Length: {len(prompt.messages)} messages
|
||||
Length: {len(prompt)} messages
|
||||
{formatted_messages}
|
||||
==========================================
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import enum
|
||||
from logging import Logger
|
||||
from textwrap import indent
|
||||
from typing import Literal, Optional
|
||||
from typing import Optional
|
||||
|
||||
from jsonschema import Draft7Validator
|
||||
from jsonschema import Draft7Validator, ValidationError
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@@ -58,10 +57,35 @@ class JSONSchema(BaseModel):
|
||||
|
||||
@staticmethod
|
||||
def from_dict(schema: dict) -> "JSONSchema":
|
||||
def resolve_references(schema: dict, definitions: dict) -> dict:
|
||||
"""
|
||||
Recursively resolve type $refs in the JSON schema with their definitions.
|
||||
"""
|
||||
if isinstance(schema, dict):
|
||||
if "$ref" in schema:
|
||||
ref_path = schema["$ref"].split("/")[
|
||||
2:
|
||||
] # Split and remove '#/definitions'
|
||||
ref_value = definitions
|
||||
for key in ref_path:
|
||||
ref_value = ref_value[key]
|
||||
return resolve_references(ref_value, definitions)
|
||||
else:
|
||||
return {
|
||||
k: resolve_references(v, definitions) for k, v in schema.items()
|
||||
}
|
||||
elif isinstance(schema, list):
|
||||
return [resolve_references(item, definitions) for item in schema]
|
||||
else:
|
||||
return schema
|
||||
|
||||
definitions = schema.get("definitions", {})
|
||||
schema = resolve_references(schema, definitions)
|
||||
|
||||
return JSONSchema(
|
||||
description=schema.get("description"),
|
||||
type=schema["type"],
|
||||
enum=schema["enum"] if "enum" in schema else None,
|
||||
enum=schema.get("enum"),
|
||||
items=JSONSchema.from_dict(schema["items"]) if "items" in schema else None,
|
||||
properties=JSONSchema.parse_properties(schema)
|
||||
if schema["type"] == "object"
|
||||
@@ -84,27 +108,24 @@ class JSONSchema(BaseModel):
|
||||
v.required = k in schema_node["required"]
|
||||
return properties
|
||||
|
||||
def validate_object(
|
||||
self, object: object, logger: Logger
|
||||
) -> tuple[Literal[True], None] | tuple[Literal[False], list]:
|
||||
def validate_object(self, object: object) -> tuple[bool, list[ValidationError]]:
|
||||
"""
|
||||
Validates a dictionary object against the JSONSchema.
|
||||
Validates an object or a value against the JSONSchema.
|
||||
|
||||
Params:
|
||||
object: The dictionary object to validate.
|
||||
object: The value/object to validate.
|
||||
schema (JSONSchema): The JSONSchema to validate against.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple where the first element is a boolean indicating whether the
|
||||
object is valid or not, and the second element is a list of errors found
|
||||
in the object, or None if the object is valid.
|
||||
bool: Indicates whether the given value or object is valid for the schema.
|
||||
list[ValidationError]: The issues with the value or object (if any).
|
||||
"""
|
||||
validator = Draft7Validator(self.to_dict())
|
||||
|
||||
if errors := sorted(validator.iter_errors(object), key=lambda e: e.path):
|
||||
return False, errors
|
||||
|
||||
return True, None
|
||||
return True, []
|
||||
|
||||
def to_typescript_object_interface(self, interface_name: str = "") -> str:
|
||||
if self.type != JSONSchema.Type.OBJECT:
|
||||
|
||||
93
autogpts/autogpt/autogpt/core/utils/json_utils.py
Normal file
93
autogpts/autogpt/autogpt/core/utils/json_utils.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import demjson3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def json_loads(json_str: str) -> Any:
|
||||
"""Parse a JSON string, tolerating minor syntax issues:
|
||||
- Missing, extra and trailing commas
|
||||
- Extraneous newlines and whitespace outside of string literals
|
||||
- Inconsistent spacing after colons and commas
|
||||
- Missing closing brackets or braces
|
||||
- Numbers: binary, hex, octal, trailing and prefixed decimal points
|
||||
- Different encodings
|
||||
- Surrounding markdown code block
|
||||
- Comments
|
||||
|
||||
Args:
|
||||
json_str: The JSON string to parse.
|
||||
|
||||
Returns:
|
||||
The parsed JSON object, same as built-in json.loads.
|
||||
"""
|
||||
# Remove possible code block
|
||||
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
|
||||
match = re.search(pattern, json_str)
|
||||
|
||||
if match:
|
||||
json_str = match.group(1).strip()
|
||||
|
||||
json_result = demjson3.decode(json_str, return_errors=True)
|
||||
assert json_result is not None # by virtue of return_errors=True
|
||||
|
||||
if json_result.errors:
|
||||
logger.debug(
|
||||
"JSON parse errors:\n" + "\n".join(str(e) for e in json_result.errors)
|
||||
)
|
||||
|
||||
if json_result.object in (demjson3.syntax_error, demjson3.undefined):
|
||||
raise ValueError(
|
||||
f"Failed to parse JSON string: {json_str}", *json_result.errors
|
||||
)
|
||||
|
||||
return json_result.object
|
||||
|
||||
|
||||
def extract_dict_from_json(json_str: str) -> dict[str, Any]:
|
||||
# Sometimes the response includes the JSON in a code block with ```
|
||||
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
|
||||
match = re.search(pattern, json_str)
|
||||
|
||||
if match:
|
||||
json_str = match.group(1).strip()
|
||||
else:
|
||||
# The string may contain JSON.
|
||||
json_pattern = r"{[\s\S]*}"
|
||||
match = re.search(json_pattern, json_str)
|
||||
|
||||
if match:
|
||||
json_str = match.group()
|
||||
|
||||
result = json_loads(json_str)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(
|
||||
f"Response '''{json_str}''' evaluated to non-dict value {repr(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def extract_list_from_json(json_str: str) -> list[Any]:
|
||||
# Sometimes the response includes the JSON in a code block with ```
|
||||
pattern = r"```(?:json|JSON)*([\s\S]*?)```"
|
||||
match = re.search(pattern, json_str)
|
||||
|
||||
if match:
|
||||
json_str = match.group(1).strip()
|
||||
else:
|
||||
# The string may contain JSON.
|
||||
json_pattern = r"\[[\s\S]*\]"
|
||||
match = re.search(json_pattern, json_str)
|
||||
|
||||
if match:
|
||||
json_str = match.group()
|
||||
|
||||
result = json_loads(json_str)
|
||||
if not isinstance(result, list):
|
||||
raise ValueError(
|
||||
f"Response '''{json_str}''' evaluated to non-list value {repr(result)}"
|
||||
)
|
||||
return result
|
||||
44
autogpts/autogpt/autogpt/file_storage/__init__.py
Normal file
44
autogpts/autogpt/autogpt/file_storage/__init__.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import enum
|
||||
from pathlib import Path
|
||||
|
||||
from .base import FileStorage
|
||||
|
||||
|
||||
class FileStorageBackendName(str, enum.Enum):
|
||||
LOCAL = "local"
|
||||
GCS = "gcs"
|
||||
S3 = "s3"
|
||||
|
||||
|
||||
def get_storage(
|
||||
backend: FileStorageBackendName,
|
||||
root_path: Path = ".",
|
||||
restrict_to_root: bool = True,
|
||||
) -> FileStorage:
|
||||
match backend:
|
||||
case FileStorageBackendName.LOCAL:
|
||||
from .local import FileStorageConfiguration, LocalFileStorage
|
||||
|
||||
config = FileStorageConfiguration.from_env()
|
||||
config.root = root_path
|
||||
config.restrict_to_root = restrict_to_root
|
||||
return LocalFileStorage(config)
|
||||
case FileStorageBackendName.S3:
|
||||
from .s3 import S3FileStorage, S3FileStorageConfiguration
|
||||
|
||||
config = S3FileStorageConfiguration.from_env()
|
||||
config.root = root_path
|
||||
return S3FileStorage(config)
|
||||
case FileStorageBackendName.GCS:
|
||||
from .gcs import GCSFileStorage, GCSFileStorageConfiguration
|
||||
|
||||
config = GCSFileStorageConfiguration.from_env()
|
||||
config.root = root_path
|
||||
return GCSFileStorage(config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FileStorage",
|
||||
"FileStorageBackendName",
|
||||
"get_storage",
|
||||
]
|
||||
204
autogpts/autogpt/autogpt/file_storage/base.py
Normal file
204
autogpts/autogpt/autogpt/file_storage/base.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
The FileStorage class provides an interface for interacting with a file storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from io import IOBase, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, BinaryIO, Callable, Literal, TextIO, overload
|
||||
|
||||
from autogpt.core.configuration.schema import SystemConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileStorageConfiguration(SystemConfiguration):
|
||||
restrict_to_root: bool = True
|
||||
root: Path = Path("/")
|
||||
|
||||
|
||||
class FileStorage(ABC):
|
||||
"""A class that represents a file storage."""
|
||||
|
||||
on_write_file: Callable[[Path], Any] | None = None
|
||||
"""
|
||||
Event hook, executed after writing a file.
|
||||
|
||||
Params:
|
||||
Path: The path of the file that was written, relative to the storage root.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def root(self) -> Path:
|
||||
"""The root path of the file storage."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict file access to within the storage's root path."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
|
||||
@abstractmethod
|
||||
def initialize(self) -> None:
|
||||
"""
|
||||
Calling `initialize()` should bring the storage to a ready-to-use state.
|
||||
For example, it can create the resource in which files will be stored, if it
|
||||
doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
|
||||
"""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["w", "r"] = "r",
|
||||
binary: Literal[False] = False,
|
||||
) -> TextIO | TextIOBase:
|
||||
"""Returns a readable text file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self,
|
||||
path: str | Path,
|
||||
mode: Literal["w", "r"] = "r",
|
||||
binary: Literal[True] = True,
|
||||
) -> BinaryIO | IOBase:
|
||||
"""Returns a readable binary file-like object representing the file."""
|
||||
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IO | IOBase:
|
||||
"""Returns a readable file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the storage as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
|
||||
"""Read a file in the storage as binary."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def list_folders(
|
||||
self, path: str | Path = ".", recursive: bool = False
|
||||
) -> list[Path]:
|
||||
"""List all folders in a directory in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_dir(self, path: str | Path) -> None:
|
||||
"""Delete an empty folder in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def exists(self, path: str | Path) -> bool:
|
||||
"""Check if a file or folder exists in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||
"""Rename a file or folder in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
|
||||
@abstractmethod
|
||||
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
|
||||
"""Create a new FileStorage with a subroot of the current storage."""
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
"""Get the full path for an item in the storage.
|
||||
|
||||
Parameters:
|
||||
relative_path: The relative path to resolve in the storage.
|
||||
|
||||
Returns:
|
||||
Path: The resolved path relative to the storage.
|
||||
"""
|
||||
return self._sanitize_path(relative_path)
|
||||
|
||||
def _sanitize_path(
|
||||
self,
|
||||
path: str | Path,
|
||||
) -> Path:
|
||||
"""Resolve the relative path within the given root if possible.
|
||||
|
||||
Parameters:
|
||||
relative_path: The relative path to resolve.
|
||||
|
||||
Returns:
|
||||
Path: The resolved path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path is absolute and a root is provided.
|
||||
ValueError: If the path is outside the root and the root is restricted.
|
||||
"""
|
||||
|
||||
# Posix systems disallow null bytes in paths. Windows is agnostic about it.
|
||||
# Do an explicit check here for all sorts of null byte representations.
|
||||
if "\0" in str(path):
|
||||
raise ValueError("Embedded null byte")
|
||||
|
||||
logger.debug(f"Resolving path '{path}' in storage '{self.root}'")
|
||||
|
||||
relative_path = Path(path)
|
||||
|
||||
# Allow absolute paths if they are contained in the storage.
|
||||
if (
|
||||
relative_path.is_absolute()
|
||||
and self.restrict_to_root
|
||||
and not relative_path.is_relative_to(self.root)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Attempted to access absolute path '{relative_path}' "
|
||||
f"in storage '{self.root}'"
|
||||
)
|
||||
|
||||
full_path = self.root / relative_path
|
||||
if self.is_local:
|
||||
full_path = full_path.resolve()
|
||||
else:
|
||||
full_path = Path(os.path.normpath(full_path))
|
||||
|
||||
logger.debug(f"Joined paths as '{full_path}'")
|
||||
|
||||
if self.restrict_to_root and not full_path.is_relative_to(self.root):
|
||||
raise ValueError(
|
||||
f"Attempted to access path '{full_path}' "
|
||||
f"outside of storage '{self.root}'."
|
||||
)
|
||||
|
||||
return full_path
|
||||
213
autogpts/autogpt/autogpt/file_storage/gcs.py
Normal file
213
autogpts/autogpt/autogpt/file_storage/gcs.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""
|
||||
The GCSWorkspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in a Google Cloud Storage bucket.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCSFileStorageConfiguration(FileStorageConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
|
||||
|
||||
|
||||
class GCSFileStorage(FileStorage):
|
||||
"""A class that represents a Google Cloud Storage."""
|
||||
|
||||
_bucket: storage.Bucket
|
||||
|
||||
def __init__(self, config: GCSFileStorageConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
# Add / at the beginning of the root path
|
||||
if not self._root.is_absolute():
|
||||
self._root = Path("/").joinpath(self._root)
|
||||
|
||||
self._gcs = storage.Client()
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file storage."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
return False
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
except NotFound:
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._gcs.create_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
# We set GCS root with "/" at the beginning
|
||||
# but relative_to("/") will remove it
|
||||
# because we don't actually want it in the storage filenames
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_blob(self, path: str | Path) -> storage.Blob:
|
||||
path = self.get_path(path)
|
||||
return self._bucket.blob(str(path))
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IOBase:
|
||||
"""Open a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
blob.reload() # pin revision number to prevent version mixing while reading
|
||||
return blob.open(f"{mode}b" if binary else mode)
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
return self.open_file(path, "r", binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
blob = self._get_blob(path)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=content,
|
||||
content_type=(
|
||||
"text/plain"
|
||||
if type(content) is str
|
||||
# TODO: get MIME type from file extension or binary content
|
||||
else "application/octet-stream"
|
||||
),
|
||||
)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
path = self.get_path(path)
|
||||
return [
|
||||
Path(blob.name).relative_to(path)
|
||||
for blob in self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
]
|
||||
|
||||
def list_folders(
|
||||
self, path: str | Path = ".", recursive: bool = False
|
||||
) -> list[Path]:
|
||||
"""List 'directories' directly in a given path or recursively in the storage."""
|
||||
path = self.get_path(path)
|
||||
folder_names = set()
|
||||
|
||||
# List objects with the specified prefix and delimiter
|
||||
for blob in self._bucket.list_blobs(prefix=path):
|
||||
# Remove path prefix and the object name (last part)
|
||||
folder = Path(blob.name).relative_to(path).parent
|
||||
if not folder or folder == Path("."):
|
||||
continue
|
||||
# For non-recursive, only add the first level of folders
|
||||
if not recursive:
|
||||
folder_names.add(folder.parts[0])
|
||||
else:
|
||||
# For recursive, need to add all nested folders
|
||||
for i in range(len(folder.parts)):
|
||||
folder_names.add("/".join(folder.parts[: i + 1]))
|
||||
|
||||
return [Path(f) for f in folder_names]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the storage."""
|
||||
path = self.get_path(path)
|
||||
blob = self._bucket.blob(str(path))
|
||||
blob.delete()
|
||||
|
||||
def delete_dir(self, path: str | Path) -> None:
|
||||
"""Delete an empty folder in the storage."""
|
||||
# Since GCS does not have directories, we don't need to do anything
|
||||
pass
|
||||
|
||||
def exists(self, path: str | Path) -> bool:
|
||||
"""Check if a file or folder exists in GCS storage."""
|
||||
path = self.get_path(path)
|
||||
# Check for exact blob match (file)
|
||||
blob = self._bucket.blob(str(path))
|
||||
if blob.exists():
|
||||
return True
|
||||
# Check for any blobs with prefix (folder)
|
||||
prefix = f"{str(path).rstrip('/')}/"
|
||||
blobs = self._bucket.list_blobs(prefix=prefix, max_results=1)
|
||||
return next(blobs, None) is not None
|
||||
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
# GCS does not have directories, so we don't need to do anything
|
||||
pass
|
||||
|
||||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||
"""Rename a file or folder in the storage."""
|
||||
old_path = self.get_path(old_path)
|
||||
new_path = self.get_path(new_path)
|
||||
blob = self._bucket.blob(str(old_path))
|
||||
# If the blob with exact name exists, rename it
|
||||
if blob.exists():
|
||||
self._bucket.rename_blob(blob, new_name=str(new_path))
|
||||
return
|
||||
# Otherwise, rename all blobs with the prefix (folder)
|
||||
for blob in self._bucket.list_blobs(prefix=f"{old_path}/"):
|
||||
new_name = str(blob.name).replace(str(old_path), str(new_path), 1)
|
||||
self._bucket.rename_blob(blob, new_name=new_name)
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = self.get_path(source)
|
||||
destination = self.get_path(destination)
|
||||
# If the source is a file, copy it
|
||||
if self._bucket.blob(str(source)).exists():
|
||||
self._bucket.copy_blob(
|
||||
self._bucket.blob(str(source)), self._bucket, str(destination)
|
||||
)
|
||||
return
|
||||
# Otherwise, copy all blobs with the prefix (folder)
|
||||
for blob in self._bucket.list_blobs(prefix=f"{source}/"):
|
||||
new_name = str(blob.name).replace(str(source), str(destination), 1)
|
||||
self._bucket.copy_blob(blob, self._bucket, new_name)
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> GCSFileStorage:
|
||||
"""Create a new GCSFileStorage with a subroot of the current storage."""
|
||||
file_storage = GCSFileStorage(
|
||||
GCSFileStorageConfiguration(
|
||||
root=Path("/").joinpath(self.get_path(subroot)),
|
||||
bucket=self._bucket_name,
|
||||
)
|
||||
)
|
||||
file_storage._gcs = self._gcs
|
||||
file_storage._bucket = self._bucket
|
||||
return file_storage
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
139
autogpts/autogpt/autogpt/file_storage/local.py
Normal file
139
autogpts/autogpt/autogpt/file_storage/local.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
The LocalFileStorage class implements a FileStorage that works with local files.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import IO, Literal
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalFileStorage(FileStorage):
|
||||
"""A class that represents a file storage."""
|
||||
|
||||
def __init__(self, config: FileStorageConfiguration):
|
||||
self._root = config.root.resolve()
|
||||
self._restrict_to_root = config.restrict_to_root
|
||||
self.make_dir(self.root)
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file storage."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return self._restrict_to_root
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
return True
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IO:
|
||||
"""Open a file in the storage."""
|
||||
return self._open_file(path, f"{mode}b" if binary else mode)
|
||||
|
||||
def _open_file(self, path: str | Path, mode: str) -> IO:
|
||||
full_path = self.get_path(path)
|
||||
return open(full_path, mode) # type: ignore
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
with self._open_file(path, "rb" if binary else "r") as file:
|
||||
return file.read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
|
||||
file.write(content)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
path = self.get_path(path)
|
||||
return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]
|
||||
|
||||
def list_folders(
|
||||
self, path: str | Path = ".", recursive: bool = False
|
||||
) -> list[Path]:
|
||||
"""List directories directly in a given path or recursively."""
|
||||
path = self.get_path(path)
|
||||
if recursive:
|
||||
return [
|
||||
folder.relative_to(path)
|
||||
for folder in path.rglob("*")
|
||||
if folder.is_dir()
|
||||
]
|
||||
else:
|
||||
return [
|
||||
folder.relative_to(path) for folder in path.iterdir() if folder.is_dir()
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the storage."""
|
||||
full_path = self.get_path(path)
|
||||
full_path.unlink()
|
||||
|
||||
def delete_dir(self, path: str | Path) -> None:
|
||||
"""Delete an empty folder in the storage."""
|
||||
full_path = self.get_path(path)
|
||||
full_path.rmdir()
|
||||
|
||||
def exists(self, path: str | Path) -> bool:
|
||||
"""Check if a file or folder exists in the storage."""
|
||||
return self.get_path(path).exists()
|
||||
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
full_path = self.get_path(path)
|
||||
full_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||
"""Rename a file or folder in the storage."""
|
||||
old_path = self.get_path(old_path)
|
||||
new_path = self.get_path(new_path)
|
||||
old_path.rename(new_path)
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = self.get_path(source)
|
||||
destination = self.get_path(destination)
|
||||
if source.is_file():
|
||||
destination.write_bytes(source.read_bytes())
|
||||
else:
|
||||
destination.mkdir(exist_ok=True, parents=True)
|
||||
for file in source.rglob("*"):
|
||||
if file.is_file():
|
||||
target = destination / file.relative_to(source)
|
||||
target.parent.mkdir(exist_ok=True, parents=True)
|
||||
target.write_bytes(file.read_bytes())
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> FileStorage:
|
||||
"""Create a new LocalFileStorage with a subroot of the current storage."""
|
||||
return LocalFileStorage(
|
||||
FileStorageConfiguration(
|
||||
root=self.get_path(subroot),
|
||||
restrict_to_root=self.restrict_to_root,
|
||||
)
|
||||
)
|
||||
265
autogpts/autogpt/autogpt/file_storage/s3.py
Normal file
265
autogpts/autogpt/autogpt/file_storage/s3.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
The S3Workspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in an S3 bucket.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase, TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import boto3
|
||||
import botocore.exceptions
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
from .base import FileStorage, FileStorageConfiguration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import mypy_boto3_s3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S3FileStorageConfiguration(FileStorageConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="STORAGE_BUCKET")
|
||||
s3_endpoint_url: Optional[SecretStr] = UserConfigurable(from_env="S3_ENDPOINT_URL")
|
||||
|
||||
|
||||
class S3FileStorage(FileStorage):
|
||||
"""A class that represents an S3 storage."""
|
||||
|
||||
_bucket: mypy_boto3_s3.service_resource.Bucket
|
||||
|
||||
def __init__(self, config: S3FileStorageConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
# Add / at the beginning of the root path
|
||||
if not self._root.is_absolute():
|
||||
self._root = Path("/").joinpath(self._root)
|
||||
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
|
||||
self._s3 = boto3.resource(
|
||||
"s3",
|
||||
endpoint_url=(
|
||||
config.s3_endpoint_url.get_secret_value()
|
||||
if config.s3_endpoint_url
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file storage."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self):
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_local(self) -> bool:
|
||||
"""Whether the storage is local (i.e. on the same machine, not cloud-based)."""
|
||||
return False
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
|
||||
self._bucket = self._s3.Bucket(self._bucket_name)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if "(404)" not in str(e):
|
||||
raise
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
# We set S3 root with "/" at the beginning
|
||||
# but relative_to("/") will remove it
|
||||
# because we don't actually want it in the storage filenames
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
|
||||
"""Get an S3 object."""
|
||||
path = self.get_path(path)
|
||||
obj = self._bucket.Object(str(path))
|
||||
with contextlib.suppress(botocore.exceptions.ClientError):
|
||||
obj.load()
|
||||
return obj
|
||||
|
||||
def open_file(
|
||||
self, path: str | Path, mode: Literal["w", "r"] = "r", binary: bool = False
|
||||
) -> IOBase:
|
||||
"""Open a file in the storage."""
|
||||
obj = self._get_obj(path)
|
||||
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the storage."""
|
||||
return self.open_file(path, binary=binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the storage."""
|
||||
obj = self._get_obj(path)
|
||||
obj.put(Body=content)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list_files(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the storage."""
|
||||
path = self.get_path(path)
|
||||
if path == Path("."): # root level of bucket
|
||||
return [Path(obj.key) for obj in self._bucket.objects.all()]
|
||||
else:
|
||||
return [
|
||||
Path(obj.key).relative_to(path)
|
||||
for obj in self._bucket.objects.filter(Prefix=f"{path}/")
|
||||
]
|
||||
|
||||
def list_folders(
|
||||
self, path: str | Path = ".", recursive: bool = False
|
||||
) -> list[Path]:
|
||||
"""List 'directories' directly in a given path or recursively in the storage."""
|
||||
path = self.get_path(path)
|
||||
folder_names = set()
|
||||
|
||||
# List objects with the specified prefix and delimiter
|
||||
for obj_summary in self._bucket.objects.filter(Prefix=str(path)):
|
||||
# Remove path prefix and the object name (last part)
|
||||
folder = Path(obj_summary.key).relative_to(path).parent
|
||||
if not folder or folder == Path("."):
|
||||
continue
|
||||
# For non-recursive, only add the first level of folders
|
||||
if not recursive:
|
||||
folder_names.add(folder.parts[0])
|
||||
else:
|
||||
# For recursive, need to add all nested folders
|
||||
for i in range(len(folder.parts)):
|
||||
folder_names.add("/".join(folder.parts[: i + 1]))
|
||||
|
||||
return [Path(f) for f in folder_names]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the storage."""
|
||||
path = self.get_path(path)
|
||||
obj = self._s3.Object(self._bucket_name, str(path))
|
||||
obj.delete()
|
||||
|
||||
def delete_dir(self, path: str | Path) -> None:
|
||||
"""Delete an empty folder in the storage."""
|
||||
# S3 does not have directories, so we don't need to do anything
|
||||
pass
|
||||
|
||||
def exists(self, path: str | Path) -> bool:
|
||||
"""Check if a file or folder exists in S3 storage."""
|
||||
path = self.get_path(path)
|
||||
try:
|
||||
# Check for exact object match (file)
|
||||
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=str(path))
|
||||
return True
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
# If the object does not exist,
|
||||
# check for objects with the prefix (folder)
|
||||
prefix = f"{str(path).rstrip('/')}/"
|
||||
objs = list(self._bucket.objects.filter(Prefix=prefix, MaxKeys=1))
|
||||
return len(objs) > 0 # True if any objects exist with the prefix
|
||||
else:
|
||||
raise # Re-raise for any other client errors
|
||||
|
||||
def make_dir(self, path: str | Path) -> None:
|
||||
"""Create a directory in the storage if doesn't exist."""
|
||||
# S3 does not have directories, so we don't need to do anything
|
||||
pass
|
||||
|
||||
def rename(self, old_path: str | Path, new_path: str | Path) -> None:
|
||||
"""Rename a file or folder in the storage."""
|
||||
old_path = str(self.get_path(old_path))
|
||||
new_path = str(self.get_path(new_path))
|
||||
|
||||
try:
|
||||
# If file exists, rename it
|
||||
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=old_path)
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": old_path},
|
||||
Bucket=self._bucket_name,
|
||||
Key=new_path,
|
||||
)
|
||||
self._s3.meta.client.delete_object(Bucket=self._bucket_name, Key=old_path)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
# If the object does not exist,
|
||||
# it may be a folder
|
||||
prefix = f"{old_path.rstrip('/')}/"
|
||||
objs = list(self._bucket.objects.filter(Prefix=prefix))
|
||||
for obj in objs:
|
||||
new_key = new_path + obj.key[len(old_path) :]
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
|
||||
Bucket=self._bucket_name,
|
||||
Key=new_key,
|
||||
)
|
||||
self._s3.meta.client.delete_object(
|
||||
Bucket=self._bucket_name, Key=obj.key
|
||||
)
|
||||
else:
|
||||
raise # Re-raise for any other client errors
|
||||
|
||||
def copy(self, source: str | Path, destination: str | Path) -> None:
|
||||
"""Copy a file or folder with all contents in the storage."""
|
||||
source = str(self.get_path(source))
|
||||
destination = str(self.get_path(destination))
|
||||
|
||||
try:
|
||||
# If source is a file, copy it
|
||||
self._s3.meta.client.head_object(Bucket=self._bucket_name, Key=source)
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": source},
|
||||
Bucket=self._bucket_name,
|
||||
Key=destination,
|
||||
)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if int(e.response["ResponseMetadata"]["HTTPStatusCode"]) == 404:
|
||||
# If the object does not exist,
|
||||
# it may be a folder
|
||||
prefix = f"{source.rstrip('/')}/"
|
||||
objs = list(self._bucket.objects.filter(Prefix=prefix))
|
||||
for obj in objs:
|
||||
new_key = destination + obj.key[len(source) :]
|
||||
self._s3.meta.client.copy_object(
|
||||
CopySource={"Bucket": self._bucket_name, "Key": obj.key},
|
||||
Bucket=self._bucket_name,
|
||||
Key=new_key,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
def clone_with_subroot(self, subroot: str | Path) -> S3FileStorage:
|
||||
"""Create a new S3FileStorage with a subroot of the current storage."""
|
||||
file_storage = S3FileStorage(
|
||||
S3FileStorageConfiguration(
|
||||
bucket=self._bucket_name,
|
||||
root=Path("/").joinpath(self.get_path(subroot)),
|
||||
s3_endpoint_url=self._s3.meta.client.meta.endpoint_url,
|
||||
)
|
||||
)
|
||||
file_storage._s3 = self._s3
|
||||
file_storage._bucket = self._bucket
|
||||
return file_storage
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
@@ -1,46 +0,0 @@
|
||||
import enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .base import FileWorkspace
|
||||
|
||||
|
||||
class FileWorkspaceBackendName(str, enum.Enum):
|
||||
LOCAL = "local"
|
||||
GCS = "gcs"
|
||||
S3 = "s3"
|
||||
|
||||
|
||||
def get_workspace(
|
||||
backend: FileWorkspaceBackendName, *, id: str = "", root_path: Optional[Path] = None
|
||||
) -> FileWorkspace:
|
||||
assert bool(root_path) != bool(id), "Specify root_path or id to get workspace"
|
||||
if root_path is None:
|
||||
root_path = Path(f"/workspaces/{id}")
|
||||
|
||||
match backend:
|
||||
case FileWorkspaceBackendName.LOCAL:
|
||||
from .local import FileWorkspaceConfiguration, LocalFileWorkspace
|
||||
|
||||
config = FileWorkspaceConfiguration.from_env()
|
||||
config.root = root_path
|
||||
return LocalFileWorkspace(config)
|
||||
case FileWorkspaceBackendName.S3:
|
||||
from .s3 import S3FileWorkspace, S3FileWorkspaceConfiguration
|
||||
|
||||
config = S3FileWorkspaceConfiguration.from_env()
|
||||
config.root = root_path
|
||||
return S3FileWorkspace(config)
|
||||
case FileWorkspaceBackendName.GCS:
|
||||
from .gcs import GCSFileWorkspace, GCSFileWorkspaceConfiguration
|
||||
|
||||
config = GCSFileWorkspaceConfiguration.from_env()
|
||||
config.root = root_path
|
||||
return GCSFileWorkspace(config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FileWorkspace",
|
||||
"FileWorkspaceBackendName",
|
||||
"get_workspace",
|
||||
]
|
||||
@@ -1,164 +0,0 @@
|
||||
"""
|
||||
The FileWorkspace class provides an interface for interacting with a file workspace.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from io import IOBase, TextIOBase
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, BinaryIO, Callable, Literal, Optional, TextIO, overload
|
||||
|
||||
from autogpt.core.configuration.schema import SystemConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileWorkspaceConfiguration(SystemConfiguration):
|
||||
restrict_to_root: bool = True
|
||||
root: Path = Path("/")
|
||||
|
||||
|
||||
class FileWorkspace(ABC):
|
||||
"""A class that represents a file workspace."""
|
||||
|
||||
on_write_file: Callable[[Path], Any] | None = None
|
||||
"""
|
||||
Event hook, executed after writing a file.
|
||||
|
||||
Params:
|
||||
Path: The path of the file that was written, relative to the workspace root.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def root(self) -> Path:
|
||||
"""The root path of the file workspace."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict file access to within the workspace's root path."""
|
||||
|
||||
@abstractmethod
|
||||
def initialize(self) -> None:
|
||||
"""
|
||||
Calling `initialize()` should bring the workspace to a ready-to-use state.
|
||||
For example, it can create the resource in which files will be stored, if it
|
||||
doesn't exist yet. E.g. a folder on disk, or an S3 Bucket.
|
||||
"""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self, path: str | Path, binary: Literal[False] = False
|
||||
) -> TextIO | TextIOBase:
|
||||
"""Returns a readable text file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def open_file(
|
||||
self, path: str | Path, binary: Literal[True] = True
|
||||
) -> BinaryIO | IOBase:
|
||||
"""Returns a readable binary file-like object representing the file."""
|
||||
|
||||
@abstractmethod
|
||||
def open_file(self, path: str | Path, binary: bool = False) -> IO | IOBase:
|
||||
"""Returns a readable file-like object representing the file."""
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: Literal[False] = False) -> str:
|
||||
"""Read a file in the workspace as text."""
|
||||
...
|
||||
|
||||
@overload
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: Literal[True] = True) -> bytes:
|
||||
"""Read a file in the workspace as binary."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the workspace."""
|
||||
|
||||
@abstractmethod
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the workspace."""
|
||||
|
||||
@abstractmethod
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
"""Get the full path for an item in the workspace.
|
||||
|
||||
Parameters:
|
||||
relative_path: The relative path to resolve in the workspace.
|
||||
|
||||
Returns:
|
||||
Path: The resolved path relative to the workspace.
|
||||
"""
|
||||
return self._sanitize_path(relative_path, self.root)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_path(
|
||||
relative_path: str | Path,
|
||||
root: Optional[str | Path] = None,
|
||||
restrict_to_root: bool = True,
|
||||
) -> Path:
|
||||
"""Resolve the relative path within the given root if possible.
|
||||
|
||||
Parameters:
|
||||
relative_path: The relative path to resolve.
|
||||
root: The root path to resolve the relative path within.
|
||||
restrict_to_root: Whether to restrict the path to the root.
|
||||
|
||||
Returns:
|
||||
Path: The resolved path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the path is absolute and a root is provided.
|
||||
ValueError: If the path is outside the root and the root is restricted.
|
||||
"""
|
||||
|
||||
# Posix systems disallow null bytes in paths. Windows is agnostic about it.
|
||||
# Do an explicit check here for all sorts of null byte representations.
|
||||
|
||||
if "\0" in str(relative_path) or "\0" in str(root):
|
||||
raise ValueError("embedded null byte")
|
||||
|
||||
if root is None:
|
||||
return Path(relative_path).resolve()
|
||||
|
||||
logger.debug(f"Resolving path '{relative_path}' in workspace '{root}'")
|
||||
|
||||
root, relative_path = Path(root).resolve(), Path(relative_path)
|
||||
|
||||
logger.debug(f"Resolved root as '{root}'")
|
||||
|
||||
# Allow absolute paths if they are contained in the workspace.
|
||||
if (
|
||||
relative_path.is_absolute()
|
||||
and restrict_to_root
|
||||
and not relative_path.is_relative_to(root)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Attempted to access absolute path '{relative_path}' "
|
||||
f"in workspace '{root}'."
|
||||
)
|
||||
|
||||
full_path = root.joinpath(relative_path).resolve()
|
||||
|
||||
logger.debug(f"Joined paths as '{full_path}'")
|
||||
|
||||
if restrict_to_root and not full_path.is_relative_to(root):
|
||||
raise ValueError(
|
||||
f"Attempted to access path '{full_path}' outside of workspace '{root}'."
|
||||
)
|
||||
|
||||
return full_path
|
||||
@@ -1,113 +0,0 @@
|
||||
"""
|
||||
The GCSWorkspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in a Google Cloud Storage bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from io import IOBase
|
||||
from pathlib import Path
|
||||
|
||||
from google.cloud import storage
|
||||
from google.cloud.exceptions import NotFound
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
from .base import FileWorkspace, FileWorkspaceConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GCSFileWorkspaceConfiguration(FileWorkspaceConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="WORKSPACE_STORAGE_BUCKET")
|
||||
|
||||
|
||||
class GCSFileWorkspace(FileWorkspace):
|
||||
"""A class that represents a Google Cloud Storage workspace."""
|
||||
|
||||
_bucket: storage.Bucket
|
||||
|
||||
def __init__(self, config: GCSFileWorkspaceConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
assert self._root.is_absolute()
|
||||
|
||||
self._gcs = storage.Client()
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file workspace."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._bucket = self._gcs.get_bucket(self._bucket_name)
|
||||
except NotFound:
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._gcs.create_bucket(self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_blob(self, path: str | Path) -> storage.Blob:
|
||||
path = self.get_path(path)
|
||||
return self._bucket.blob(str(path))
|
||||
|
||||
def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
|
||||
"""Open a file in the workspace."""
|
||||
blob = self._get_blob(path)
|
||||
blob.reload() # pin revision number to prevent version mixing while reading
|
||||
return blob.open("rb" if binary else "r")
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the workspace."""
|
||||
return self.open_file(path, binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the workspace."""
|
||||
blob = self._get_blob(path)
|
||||
|
||||
blob.upload_from_string(
|
||||
data=content,
|
||||
content_type=(
|
||||
"text/plain"
|
||||
if type(content) is str
|
||||
# TODO: get MIME type from file extension or binary content
|
||||
else "application/octet-stream"
|
||||
),
|
||||
)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
return [
|
||||
Path(blob.name).relative_to(path)
|
||||
for blob in self._bucket.list_blobs(
|
||||
prefix=f"{path}/" if path != Path(".") else None
|
||||
)
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
path = self.get_path(path)
|
||||
blob = self._bucket.blob(str(path))
|
||||
blob.delete()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
@@ -1,71 +0,0 @@
|
||||
"""
|
||||
The LocalFileWorkspace class implements a FileWorkspace that works with local files.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import IO
|
||||
|
||||
from .base import FileWorkspace, FileWorkspaceConfiguration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalFileWorkspace(FileWorkspace):
|
||||
"""A class that represents a file workspace."""
|
||||
|
||||
def __init__(self, config: FileWorkspaceConfiguration):
|
||||
self._root = self._sanitize_path(config.root)
|
||||
self._restrict_to_root = config.restrict_to_root
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file workspace."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self) -> bool:
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return self._restrict_to_root
|
||||
|
||||
def initialize(self) -> None:
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
def open_file(self, path: str | Path, binary: bool = False) -> IO:
|
||||
"""Open a file in the workspace."""
|
||||
return self._open_file(path, "rb" if binary else "r")
|
||||
|
||||
def _open_file(self, path: str | Path, mode: str = "r") -> IO:
|
||||
full_path = self.get_path(path)
|
||||
return open(full_path, mode) # type: ignore
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the workspace."""
|
||||
with self._open_file(path, "rb" if binary else "r") as file:
|
||||
return file.read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the workspace."""
|
||||
with self._open_file(path, "wb" if type(content) is bytes else "w") as file:
|
||||
file.write(content)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
return [file.relative_to(path) for file in path.rglob("*") if file.is_file()]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
full_path = self.get_path(path)
|
||||
full_path.unlink()
|
||||
@@ -1,128 +0,0 @@
|
||||
"""
|
||||
The S3Workspace class provides an interface for interacting with a file workspace, and
|
||||
stores the files in an S3 bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from io import IOBase, TextIOWrapper
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import boto3
|
||||
import botocore.exceptions
|
||||
from pydantic import SecretStr
|
||||
|
||||
from autogpt.core.configuration.schema import UserConfigurable
|
||||
|
||||
from .base import FileWorkspace, FileWorkspaceConfiguration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import mypy_boto3_s3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S3FileWorkspaceConfiguration(FileWorkspaceConfiguration):
|
||||
bucket: str = UserConfigurable("autogpt", from_env="WORKSPACE_STORAGE_BUCKET")
|
||||
s3_endpoint_url: Optional[SecretStr] = UserConfigurable(
|
||||
from_env=lambda: SecretStr(v) if (v := os.getenv("S3_ENDPOINT_URL")) else None
|
||||
)
|
||||
|
||||
|
||||
class S3FileWorkspace(FileWorkspace):
|
||||
"""A class that represents an S3 workspace."""
|
||||
|
||||
_bucket: mypy_boto3_s3.service_resource.Bucket
|
||||
|
||||
def __init__(self, config: S3FileWorkspaceConfiguration):
|
||||
self._bucket_name = config.bucket
|
||||
self._root = config.root
|
||||
assert self._root.is_absolute()
|
||||
|
||||
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
|
||||
self._s3 = boto3.resource(
|
||||
"s3",
|
||||
endpoint_url=config.s3_endpoint_url.get_secret_value()
|
||||
if config.s3_endpoint_url
|
||||
else None,
|
||||
)
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""The root directory of the file workspace."""
|
||||
return self._root
|
||||
|
||||
@property
|
||||
def restrict_to_root(self):
|
||||
"""Whether to restrict generated paths to the root."""
|
||||
return True
|
||||
|
||||
def initialize(self) -> None:
|
||||
logger.debug(f"Initializing {repr(self)}...")
|
||||
try:
|
||||
self._s3.meta.client.head_bucket(Bucket=self._bucket_name)
|
||||
self._bucket = self._s3.Bucket(self._bucket_name)
|
||||
except botocore.exceptions.ClientError as e:
|
||||
if "(404)" not in str(e):
|
||||
raise
|
||||
logger.info(f"Bucket '{self._bucket_name}' does not exist; creating it...")
|
||||
self._bucket = self._s3.create_bucket(Bucket=self._bucket_name)
|
||||
|
||||
def get_path(self, relative_path: str | Path) -> Path:
|
||||
return super().get_path(relative_path).relative_to("/")
|
||||
|
||||
def _get_obj(self, path: str | Path) -> mypy_boto3_s3.service_resource.Object:
|
||||
"""Get an S3 object."""
|
||||
path = self.get_path(path)
|
||||
obj = self._bucket.Object(str(path))
|
||||
with contextlib.suppress(botocore.exceptions.ClientError):
|
||||
obj.load()
|
||||
return obj
|
||||
|
||||
def open_file(self, path: str | Path, binary: bool = False) -> IOBase:
|
||||
"""Open a file in the workspace."""
|
||||
obj = self._get_obj(path)
|
||||
return obj.get()["Body"] if binary else TextIOWrapper(obj.get()["Body"])
|
||||
|
||||
def read_file(self, path: str | Path, binary: bool = False) -> str | bytes:
|
||||
"""Read a file in the workspace."""
|
||||
return self.open_file(path, binary).read()
|
||||
|
||||
async def write_file(self, path: str | Path, content: str | bytes) -> None:
|
||||
"""Write to a file in the workspace."""
|
||||
obj = self._get_obj(path)
|
||||
obj.put(Body=content)
|
||||
|
||||
if self.on_write_file:
|
||||
path = Path(path)
|
||||
if path.is_absolute():
|
||||
path = path.relative_to(self.root)
|
||||
res = self.on_write_file(path)
|
||||
if inspect.isawaitable(res):
|
||||
await res
|
||||
|
||||
def list(self, path: str | Path = ".") -> list[Path]:
|
||||
"""List all files (recursively) in a directory in the workspace."""
|
||||
path = self.get_path(path)
|
||||
if path == Path("."): # root level of bucket
|
||||
return [Path(obj.key) for obj in self._bucket.objects.all()]
|
||||
else:
|
||||
return [
|
||||
Path(obj.key).relative_to(path)
|
||||
for obj in self._bucket.objects.filter(Prefix=f"{path}/")
|
||||
]
|
||||
|
||||
def delete_file(self, path: str | Path) -> None:
|
||||
"""Delete a file in the workspace."""
|
||||
path = self.get_path(path)
|
||||
obj = self._s3.Object(self._bucket_name, str(path))
|
||||
obj.delete()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{__class__.__name__}(bucket='{self._bucket_name}', root={self._root})"
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Utilities for the json_fixes package."""
|
||||
import ast
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_dict_from_response(response_content: str) -> dict[str, Any]:
|
||||
# Sometimes the response includes the JSON in a code block with ```
|
||||
pattern = r"```([\s\S]*?)```"
|
||||
match = re.search(pattern, response_content)
|
||||
|
||||
if match:
|
||||
response_content = match.group(1).strip()
|
||||
# Remove language names in code blocks
|
||||
response_content = response_content.lstrip("json")
|
||||
else:
|
||||
# The string may contain JSON.
|
||||
json_pattern = r"{[\s\S]*}"
|
||||
match = re.search(json_pattern, response_content)
|
||||
|
||||
if match:
|
||||
response_content = match.group()
|
||||
|
||||
# Response content comes from OpenAI as a Python `str(content_dict)`.
|
||||
# `literal_eval` does the reverse of `str(dict)`.
|
||||
result = ast.literal_eval(response_content)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError(
|
||||
f"Response '''{response_content}''' evaluated to "
|
||||
f"non-dict value {repr(result)}"
|
||||
)
|
||||
return result
|
||||
@@ -1,115 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
from openai.types import Model
|
||||
|
||||
from autogpt.core.resource.model_providers.openai import (
|
||||
OPEN_AI_MODELS,
|
||||
OpenAICredentials,
|
||||
)
|
||||
from autogpt.core.resource.model_providers.schema import ChatModelInfo
|
||||
from autogpt.singleton import Singleton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApiManager(metaclass=Singleton):
|
||||
def __init__(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
self.total_cost = 0
|
||||
self.total_budget = 0
|
||||
self.models: Optional[list[Model]] = None
|
||||
|
||||
def reset(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
self.total_cost = 0
|
||||
self.total_budget = 0.0
|
||||
self.models = None
|
||||
|
||||
def update_cost(self, prompt_tokens, completion_tokens, model):
|
||||
"""
|
||||
Update the total cost, prompt tokens, and completion tokens.
|
||||
|
||||
Args:
|
||||
prompt_tokens (int): The number of tokens used in the prompt.
|
||||
completion_tokens (int): The number of tokens used in the completion.
|
||||
model (str): The model used for the API call.
|
||||
"""
|
||||
# the .model property in API responses can contain version suffixes like -v2
|
||||
model = model[:-3] if model.endswith("-v2") else model
|
||||
model_info = OPEN_AI_MODELS[model]
|
||||
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
self.total_cost += prompt_tokens * model_info.prompt_token_cost / 1000
|
||||
if isinstance(model_info, ChatModelInfo):
|
||||
self.total_cost += (
|
||||
completion_tokens * model_info.completion_token_cost / 1000
|
||||
)
|
||||
|
||||
logger.debug(f"Total running cost: ${self.total_cost:.3f}")
|
||||
|
||||
def set_total_budget(self, total_budget):
|
||||
"""
|
||||
Sets the total user-defined budget for API calls.
|
||||
|
||||
Args:
|
||||
total_budget (float): The total budget for API calls.
|
||||
"""
|
||||
self.total_budget = total_budget
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
"""
|
||||
Get the total number of prompt tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of prompt tokens.
|
||||
"""
|
||||
return self.total_prompt_tokens
|
||||
|
||||
def get_total_completion_tokens(self):
|
||||
"""
|
||||
Get the total number of completion tokens.
|
||||
|
||||
Returns:
|
||||
int: The total number of completion tokens.
|
||||
"""
|
||||
return self.total_completion_tokens
|
||||
|
||||
def get_total_cost(self):
|
||||
"""
|
||||
Get the total cost of API calls.
|
||||
|
||||
Returns:
|
||||
float: The total cost of API calls.
|
||||
"""
|
||||
return self.total_cost
|
||||
|
||||
def get_total_budget(self):
|
||||
"""
|
||||
Get the total user-defined budget for API calls.
|
||||
|
||||
Returns:
|
||||
float: The total budget for API calls.
|
||||
"""
|
||||
return self.total_budget
|
||||
|
||||
def get_models(self, openai_credentials: OpenAICredentials) -> List[Model]:
|
||||
"""
|
||||
Get list of available GPT models.
|
||||
|
||||
Returns:
|
||||
list[Model]: List of available GPT models.
|
||||
"""
|
||||
if self.models is None:
|
||||
all_models = (
|
||||
OpenAI(**openai_credentials.get_api_access_kwargs()).models.list().data
|
||||
)
|
||||
self.models = [model for model in all_models if "gpt" in model.id]
|
||||
|
||||
return self.models
|
||||
@@ -1,10 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Callable, Iterable, TypeVar
|
||||
from typing import TYPE_CHECKING, Callable, Iterable, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.models.command import Command
|
||||
|
||||
from autogpt.core.resource.model_providers import CompletionModelFunction
|
||||
from autogpt.models.command import Command
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -12,7 +14,7 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T", bound=Callable)
|
||||
|
||||
|
||||
def get_openai_command_specs(
|
||||
def function_specs_from_commands(
|
||||
commands: Iterable[Command],
|
||||
) -> list[CompletionModelFunction]:
|
||||
"""Get OpenAI-consumable function specs for the agent's available commands.
|
||||
@@ -20,7 +22,7 @@ def get_openai_command_specs(
|
||||
"""
|
||||
return [
|
||||
CompletionModelFunction(
|
||||
name=command.name,
|
||||
name=command.names[0],
|
||||
description=command.description,
|
||||
parameters={param.name: param.spec for param in command.parameters},
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .config import configure_chat_plugins, configure_logging
|
||||
from .config import configure_logging
|
||||
from .helpers import user_friendly_output
|
||||
from .log_cycle import (
|
||||
CURRENT_CONTEXT_FILE_NAME,
|
||||
@@ -13,7 +13,6 @@ from .log_cycle import (
|
||||
|
||||
__all__ = [
|
||||
"configure_logging",
|
||||
"configure_chat_plugins",
|
||||
"user_friendly_output",
|
||||
"CURRENT_CONTEXT_FILE_NAME",
|
||||
"NEXT_ACTION_FILE_NAME",
|
||||
|
||||
@@ -8,11 +8,9 @@ import sys
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from auto_gpt_plugin_template import AutoGPTPluginTemplate
|
||||
from openai._base_client import log as openai_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config import Config
|
||||
from autogpt.speech import TTSConfig
|
||||
|
||||
from autogpt.core.configuration import SystemConfiguration, UserConfigurable
|
||||
@@ -34,8 +32,6 @@ DEBUG_LOG_FORMAT = (
|
||||
SPEECH_OUTPUT_LOGGER = "VOICE"
|
||||
USER_FRIENDLY_OUTPUT_LOGGER = "USER_FRIENDLY_OUTPUT"
|
||||
|
||||
_chat_plugins: list[AutoGPTPluginTemplate] = []
|
||||
|
||||
|
||||
class LogFormatName(str, enum.Enum):
|
||||
SIMPLE = "simple"
|
||||
@@ -57,8 +53,7 @@ class LoggingConfig(SystemConfiguration):
|
||||
|
||||
# Console output
|
||||
log_format: LogFormatName = UserConfigurable(
|
||||
default=LogFormatName.SIMPLE,
|
||||
from_env=lambda: LogFormatName(os.getenv("LOG_FORMAT", "simple")),
|
||||
default=LogFormatName.SIMPLE, from_env="LOG_FORMAT"
|
||||
)
|
||||
plain_console_output: bool = UserConfigurable(
|
||||
default=False,
|
||||
@@ -69,46 +64,80 @@ class LoggingConfig(SystemConfiguration):
|
||||
log_dir: Path = LOG_DIR
|
||||
log_file_format: Optional[LogFormatName] = UserConfigurable(
|
||||
default=LogFormatName.SIMPLE,
|
||||
from_env=lambda: LogFormatName(
|
||||
os.getenv("LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple"))
|
||||
from_env=lambda: os.getenv(
|
||||
"LOG_FILE_FORMAT", os.getenv("LOG_FORMAT", "simple")
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def configure_logging(
|
||||
level: int = logging.INFO,
|
||||
log_dir: Path = LOG_DIR,
|
||||
log_format: Optional[LogFormatName] = None,
|
||||
log_file_format: Optional[LogFormatName] = None,
|
||||
plain_console_output: bool = False,
|
||||
debug: bool = False,
|
||||
level: Optional[int | str] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
log_format: Optional[LogFormatName | str] = None,
|
||||
log_file_format: Optional[LogFormatName | str] = None,
|
||||
plain_console_output: Optional[bool] = None,
|
||||
config: Optional[LoggingConfig] = None,
|
||||
tts_config: Optional[TTSConfig] = None,
|
||||
) -> None:
|
||||
"""Configure the native logging module.
|
||||
"""Configure the native logging module, based on the environment config and any
|
||||
specified overrides.
|
||||
|
||||
Arguments override values specified in the environment.
|
||||
Overrides are also applied to `config`, if passed.
|
||||
|
||||
Should be usable as `configure_logging(**config.logging.dict())`, where
|
||||
`config.logging` is a `LoggingConfig` object.
|
||||
"""
|
||||
if debug and level:
|
||||
raise ValueError("Only one of either 'debug' and 'level' arguments may be set")
|
||||
|
||||
# Auto-adjust default log format based on log level
|
||||
log_format = log_format or (
|
||||
LogFormatName.SIMPLE if level != logging.DEBUG else LogFormatName.DEBUG
|
||||
# Parse arguments
|
||||
if isinstance(level, str):
|
||||
if type(_level := logging.getLevelName(level.upper())) is int:
|
||||
level = _level
|
||||
else:
|
||||
raise ValueError(f"Unknown log level '{level}'")
|
||||
if isinstance(log_format, str):
|
||||
if log_format in LogFormatName._value2member_map_:
|
||||
log_format = LogFormatName(log_format)
|
||||
elif not isinstance(log_format, LogFormatName):
|
||||
raise ValueError(f"Unknown log format '{log_format}'")
|
||||
if isinstance(log_file_format, str):
|
||||
if log_file_format in LogFormatName._value2member_map_:
|
||||
log_file_format = LogFormatName(log_file_format)
|
||||
elif not isinstance(log_file_format, LogFormatName):
|
||||
raise ValueError(f"Unknown log format '{log_format}'")
|
||||
|
||||
config = config or LoggingConfig.from_env()
|
||||
|
||||
# Aggregate env config + arguments
|
||||
config.level = logging.DEBUG if debug else level or config.level
|
||||
config.log_dir = log_dir or config.log_dir
|
||||
config.log_format = log_format or (
|
||||
LogFormatName.DEBUG if debug else config.log_format
|
||||
)
|
||||
config.log_file_format = log_file_format or log_format or config.log_file_format
|
||||
config.plain_console_output = (
|
||||
plain_console_output
|
||||
if plain_console_output is not None
|
||||
else config.plain_console_output
|
||||
)
|
||||
log_file_format = log_file_format or log_format
|
||||
|
||||
structured_logging = log_format == LogFormatName.STRUCTURED
|
||||
|
||||
if structured_logging:
|
||||
plain_console_output = True
|
||||
log_file_format = None
|
||||
# Structured logging is used for cloud environments,
|
||||
# where logging to a file makes no sense.
|
||||
if config.log_format == LogFormatName.STRUCTURED:
|
||||
config.plain_console_output = True
|
||||
config.log_file_format = None
|
||||
|
||||
# create log directory if it doesn't exist
|
||||
if not log_dir.exists():
|
||||
log_dir.mkdir()
|
||||
if not config.log_dir.exists():
|
||||
config.log_dir.mkdir()
|
||||
|
||||
log_handlers: list[logging.Handler] = []
|
||||
|
||||
if log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
|
||||
console_format_template = TEXT_LOG_FORMAT_MAP[log_format]
|
||||
if config.log_format in (LogFormatName.DEBUG, LogFormatName.SIMPLE):
|
||||
console_format_template = TEXT_LOG_FORMAT_MAP[config.log_format]
|
||||
console_formatter = AutoGptFormatter(console_format_template)
|
||||
else:
|
||||
console_formatter = StructuredLoggingFormatter()
|
||||
@@ -116,7 +145,7 @@ def configure_logging(
|
||||
|
||||
# Console output handlers
|
||||
stdout = logging.StreamHandler(stream=sys.stdout)
|
||||
stdout.setLevel(level)
|
||||
stdout.setLevel(config.level)
|
||||
stdout.addFilter(BelowLevelFilter(logging.WARNING))
|
||||
stdout.setFormatter(console_formatter)
|
||||
stderr = logging.StreamHandler()
|
||||
@@ -133,7 +162,7 @@ def configure_logging(
|
||||
user_friendly_output_logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
|
||||
user_friendly_output_logger.setLevel(logging.INFO)
|
||||
user_friendly_output_logger.addHandler(
|
||||
typing_console_handler if not plain_console_output else stdout
|
||||
typing_console_handler if not config.plain_console_output else stdout
|
||||
)
|
||||
if tts_config:
|
||||
user_friendly_output_logger.addHandler(TTSHandler(tts_config))
|
||||
@@ -141,22 +170,26 @@ def configure_logging(
|
||||
user_friendly_output_logger.propagate = False
|
||||
|
||||
# File output handlers
|
||||
if log_file_format is not None:
|
||||
if level < logging.ERROR:
|
||||
file_output_format_template = TEXT_LOG_FORMAT_MAP[log_file_format]
|
||||
if config.log_file_format is not None:
|
||||
if config.level < logging.ERROR:
|
||||
file_output_format_template = TEXT_LOG_FORMAT_MAP[config.log_file_format]
|
||||
file_output_formatter = AutoGptFormatter(
|
||||
file_output_format_template, no_color=True
|
||||
)
|
||||
|
||||
# INFO log file handler
|
||||
activity_log_handler = logging.FileHandler(log_dir / LOG_FILE, "a", "utf-8")
|
||||
activity_log_handler.setLevel(level)
|
||||
activity_log_handler = logging.FileHandler(
|
||||
config.log_dir / LOG_FILE, "a", "utf-8"
|
||||
)
|
||||
activity_log_handler.setLevel(config.level)
|
||||
activity_log_handler.setFormatter(file_output_formatter)
|
||||
log_handlers += [activity_log_handler]
|
||||
user_friendly_output_logger.addHandler(activity_log_handler)
|
||||
|
||||
# ERROR log file handler
|
||||
error_log_handler = logging.FileHandler(log_dir / ERROR_LOG_FILE, "a", "utf-8")
|
||||
error_log_handler = logging.FileHandler(
|
||||
config.log_dir / ERROR_LOG_FILE, "a", "utf-8"
|
||||
)
|
||||
error_log_handler.setLevel(logging.ERROR)
|
||||
error_log_handler.setFormatter(
|
||||
AutoGptFormatter(DEBUG_LOG_FORMAT, no_color=True)
|
||||
@@ -167,7 +200,7 @@ def configure_logging(
|
||||
# Configure the root logger
|
||||
logging.basicConfig(
|
||||
format=console_format_template,
|
||||
level=level,
|
||||
level=config.level,
|
||||
handlers=log_handlers,
|
||||
)
|
||||
|
||||
@@ -185,19 +218,3 @@ def configure_logging(
|
||||
|
||||
# Disable debug logging from OpenAI library
|
||||
openai_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def configure_chat_plugins(config: Config) -> None:
|
||||
"""Configure chat plugins for use by the logging module"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Add chat plugins capable of report to logger
|
||||
if config.chat_messages_enabled:
|
||||
if _chat_plugins:
|
||||
_chat_plugins.clear()
|
||||
|
||||
for plugin in config.plugins:
|
||||
if hasattr(plugin, "can_handle_report") and plugin.can_handle_report():
|
||||
logger.debug(f"Loaded plugin into logger: {plugin.__class__.__name__}")
|
||||
_chat_plugins.append(plugin)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Optional
|
||||
|
||||
from colorama import Fore
|
||||
|
||||
from .config import SPEECH_OUTPUT_LOGGER, USER_FRIENDLY_OUTPUT_LOGGER, _chat_plugins
|
||||
from .config import SPEECH_OUTPUT_LOGGER, USER_FRIENDLY_OUTPUT_LOGGER
|
||||
|
||||
|
||||
def user_friendly_output(
|
||||
@@ -21,10 +21,6 @@ def user_friendly_output(
|
||||
"""
|
||||
logger = logging.getLogger(USER_FRIENDLY_OUTPUT_LOGGER)
|
||||
|
||||
if _chat_plugins:
|
||||
for plugin in _chat_plugins:
|
||||
plugin.report(f"{title}: {message}")
|
||||
|
||||
logger.log(
|
||||
level,
|
||||
message,
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from typing import Any, Sequence, overload
|
||||
|
||||
import numpy as np
|
||||
@@ -51,16 +50,9 @@ async def get_embedding(
|
||||
|
||||
if isinstance(input, str):
|
||||
input = input.replace("\n", " ")
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return _get_embedding_with_plugin(input, config)
|
||||
|
||||
elif multiple and isinstance(input[0], str):
|
||||
input = [text.replace("\n", " ") for text in input]
|
||||
|
||||
with suppress(NotImplementedError):
|
||||
return [_get_embedding_with_plugin(i, config) for i in input]
|
||||
|
||||
model = config.embedding_model
|
||||
|
||||
logger.debug(
|
||||
@@ -86,13 +78,3 @@ async def get_embedding(
|
||||
)
|
||||
embeddings.append(result.embedding)
|
||||
return embeddings
|
||||
|
||||
|
||||
def _get_embedding_with_plugin(text: str, config: Config) -> Embedding:
|
||||
for plugin in config.plugins:
|
||||
if plugin.can_handle_text_embedding(text):
|
||||
embedding = plugin.handle_text_embedding(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterator, Literal, Optional
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Generic, Iterator, Literal, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.generics import GenericModel
|
||||
|
||||
from autogpt.agents.base import BaseAgentActionProposal
|
||||
from autogpt.models.utils import ModelWithSummary
|
||||
from autogpt.processing.text import summarize_text
|
||||
from autogpt.prompts.utils import format_numbered_list, indent
|
||||
|
||||
|
||||
class Action(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
reasoning: str
|
||||
|
||||
def format_call(self) -> str:
|
||||
return (
|
||||
f"{self.name}"
|
||||
f"({', '.join([f'{a}={repr(v)}' for a, v in self.args.items()])})"
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from autogpt.config.config import Config
|
||||
from autogpt.core.resource.model_providers import ChatModelProvider
|
||||
|
||||
|
||||
class ActionSuccessResult(BaseModel):
|
||||
@@ -80,33 +77,62 @@ class ActionInterruptedByHuman(BaseModel):
|
||||
|
||||
ActionResult = ActionSuccessResult | ActionErrorResult | ActionInterruptedByHuman
|
||||
|
||||
AP = TypeVar("AP", bound=BaseAgentActionProposal)
|
||||
|
||||
class Episode(BaseModel):
|
||||
action: Action
|
||||
|
||||
class Episode(GenericModel, Generic[AP]):
|
||||
action: AP
|
||||
result: ActionResult | None
|
||||
summary: str | None = None
|
||||
|
||||
def format(self):
|
||||
step = f"Executed `{self.action.use_tool}`\n"
|
||||
reasoning = (
|
||||
_r.summary()
|
||||
if isinstance(_r := self.action.thoughts, ModelWithSummary)
|
||||
else _r
|
||||
)
|
||||
step += f'- **Reasoning:** "{reasoning}"\n'
|
||||
step += (
|
||||
"- **Status:** "
|
||||
f"`{self.result.status if self.result else 'did_not_finish'}`\n"
|
||||
)
|
||||
if self.result:
|
||||
if self.result.status == "success":
|
||||
result = str(self.result)
|
||||
result = "\n" + indent(result) if "\n" in result else result
|
||||
step += f"- **Output:** {result}"
|
||||
elif self.result.status == "error":
|
||||
step += f"- **Reason:** {self.result.reason}\n"
|
||||
if self.result.error:
|
||||
step += f"- **Error:** {self.result.error}\n"
|
||||
elif self.result.status == "interrupted_by_human":
|
||||
step += f"- **Feedback:** {self.result.feedback}\n"
|
||||
return step
|
||||
|
||||
def __str__(self) -> str:
|
||||
executed_action = f"Executed `{self.action.format_call()}`"
|
||||
executed_action = f"Executed `{self.action.use_tool}`"
|
||||
action_result = f": {self.result}" if self.result else "."
|
||||
return executed_action + action_result
|
||||
|
||||
|
||||
class EpisodicActionHistory(BaseModel):
|
||||
class EpisodicActionHistory(GenericModel, Generic[AP]):
|
||||
"""Utility container for an action history"""
|
||||
|
||||
episodes: list[Episode] = Field(default_factory=list)
|
||||
episodes: list[Episode[AP]] = Field(default_factory=list)
|
||||
cursor: int = 0
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def current_episode(self) -> Episode | None:
|
||||
def current_episode(self) -> Episode[AP] | None:
|
||||
if self.cursor == len(self):
|
||||
return None
|
||||
return self[self.cursor]
|
||||
|
||||
def __getitem__(self, key: int) -> Episode:
|
||||
def __getitem__(self, key: int) -> Episode[AP]:
|
||||
return self.episodes[key]
|
||||
|
||||
def __iter__(self) -> Iterator[Episode]:
|
||||
def __iter__(self) -> Iterator[Episode[AP]]:
|
||||
return iter(self.episodes)
|
||||
|
||||
def __len__(self) -> int:
|
||||
@@ -115,7 +141,7 @@ class EpisodicActionHistory(BaseModel):
|
||||
def __bool__(self) -> bool:
|
||||
return len(self.episodes) > 0
|
||||
|
||||
def register_action(self, action: Action) -> None:
|
||||
def register_action(self, action: AP) -> None:
|
||||
if not self.current_episode:
|
||||
self.episodes.append(Episode(action=action, result=None))
|
||||
assert self.current_episode
|
||||
@@ -148,29 +174,48 @@ class EpisodicActionHistory(BaseModel):
|
||||
self.episodes = self.episodes[:-number_of_episodes]
|
||||
self.cursor = len(self.episodes)
|
||||
|
||||
async def handle_compression(
|
||||
self, llm_provider: ChatModelProvider, app_config: Config
|
||||
) -> None:
|
||||
"""Compresses each episode in the action history using an LLM.
|
||||
|
||||
This method iterates over all episodes in the action history without a summary,
|
||||
and generates a summary for them using an LLM.
|
||||
"""
|
||||
compress_instruction = (
|
||||
"The text represents an action, the reason for its execution, "
|
||||
"and its result. "
|
||||
"Condense the action taken and its result into one line. "
|
||||
"Preserve any specific factual information gathered by the action."
|
||||
)
|
||||
async with self._lock:
|
||||
# Gather all episodes without a summary
|
||||
episodes_to_summarize = [ep for ep in self.episodes if ep.summary is None]
|
||||
|
||||
# Parallelize summarization calls
|
||||
summarize_coroutines = [
|
||||
summarize_text(
|
||||
episode.format(),
|
||||
instruction=compress_instruction,
|
||||
llm_provider=llm_provider,
|
||||
config=app_config,
|
||||
)
|
||||
for episode in episodes_to_summarize
|
||||
]
|
||||
summaries = await asyncio.gather(*summarize_coroutines)
|
||||
|
||||
# Assign summaries to episodes
|
||||
for episode, (summary, _) in zip(episodes_to_summarize, summaries):
|
||||
episode.summary = summary
|
||||
|
||||
def fmt_list(self) -> str:
|
||||
return format_numbered_list(self.episodes)
|
||||
|
||||
def fmt_paragraph(self) -> str:
|
||||
steps: list[str] = []
|
||||
|
||||
for i, c in enumerate(self.episodes, 1):
|
||||
step = f"### Step {i}: Executed `{c.action.format_call()}`\n"
|
||||
step += f'- **Reasoning:** "{c.action.reasoning}"\n'
|
||||
step += (
|
||||
f"- **Status:** `{c.result.status if c.result else 'did_not_finish'}`\n"
|
||||
)
|
||||
if c.result:
|
||||
if c.result.status == "success":
|
||||
result = str(c.result)
|
||||
result = "\n" + indent(result) if "\n" in result else result
|
||||
step += f"- **Output:** {result}"
|
||||
elif c.result.status == "error":
|
||||
step += f"- **Reason:** {c.result.reason}\n"
|
||||
if c.result.error:
|
||||
step += f"- **Error:** {c.result.error}\n"
|
||||
elif c.result.status == "interrupted_by_human":
|
||||
step += f"- **Feedback:** {c.result.feedback}\n"
|
||||
for i, episode in enumerate(self.episodes, 1):
|
||||
step = f"### Step {i}: {episode.format()}\n"
|
||||
|
||||
steps.append(step)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user