mirror of
https://github.com/googleapis/genai-toolbox.git
synced 2026-02-18 11:02:26 -05:00
Compare commits
11 Commits
auth-files
...
lsc-177138
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
41b04ea66c | ||
|
|
a21d9a158b | ||
|
|
2618bd7673 | ||
|
|
8ea39ec32f | ||
|
|
aa270b2630 | ||
|
|
e1bd98ef5b | ||
|
|
fa148c60a7 | ||
|
|
6e87349431 | ||
|
|
3fe4e2b671 | ||
|
|
271f39d4b9 | ||
|
|
97b0e7d3ac |
@@ -305,4 +305,4 @@ substitutions:
|
||||
_AR_HOSTNAME: ${_REGION}-docker.pkg.dev
|
||||
_AR_REPO_NAME: toolbox-dev
|
||||
_BUCKET_NAME: genai-toolbox-dev
|
||||
_DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox
|
||||
_DOCKER_URI: ${_AR_HOSTNAME}/${PROJECT_ID}/${_AR_REPO_NAME}/toolbox
|
||||
@@ -212,6 +212,26 @@ steps:
|
||||
bigquery \
|
||||
bigquery
|
||||
|
||||
- id: "cloud-gda"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
- "CLOUD_GDA_PROJECT=$PROJECT_ID"
|
||||
- "SERVICE_ACCOUNT_EMAIL=$SERVICE_ACCOUNT_EMAIL"
|
||||
secretEnv: ["CLIENT_ID"]
|
||||
volumes:
|
||||
- name: "go"
|
||||
path: "/gopath"
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Cloud Gemini Data Analytics" \
|
||||
cloudgda \
|
||||
cloudgda
|
||||
|
||||
- id: "dataplex"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
@@ -826,8 +846,8 @@ steps:
|
||||
cassandra
|
||||
|
||||
- id: "oracle"
|
||||
name: golang:1
|
||||
waitFor: ["compile-test-binary"]
|
||||
name: ghcr.io/oracle/oraclelinux9-instantclient:23
|
||||
waitFor: ["install-dependencies"]
|
||||
entrypoint: /bin/bash
|
||||
env:
|
||||
- "GOPATH=/gopath"
|
||||
@@ -840,10 +860,25 @@ steps:
|
||||
args:
|
||||
- -c
|
||||
- |
|
||||
.ci/test_with_coverage.sh \
|
||||
"Oracle" \
|
||||
oracle \
|
||||
oracle
|
||||
# Install the C compiler and Oracle SDK headers needed for cgo
|
||||
dnf install -y gcc oracle-instantclient-devel
|
||||
# Install Go
|
||||
curl -L -o go.tar.gz "https://go.dev/dl/go1.25.1.linux-amd64.tar.gz"
|
||||
tar -C /usr/local -xzf go.tar.gz
|
||||
export PATH="/usr/local/go/bin:$$PATH"
|
||||
|
||||
go test -v ./internal/sources/oracle/... \
|
||||
-coverprofile=oracle_coverage.out \
|
||||
-coverpkg=./internal/sources/oracle/...,./internal/tools/oracle/...
|
||||
|
||||
# Coverage check
|
||||
total_coverage=$(go tool cover -func=oracle_coverage.out | grep "total:" | awk '{print $3}')
|
||||
echo "Oracle total coverage: $total_coverage"
|
||||
coverage_numeric=$(echo "$total_coverage" | sed 's/%//')
|
||||
if awk -v cov="$coverage_numeric" 'BEGIN {exit !(cov < 30)}'; then
|
||||
echo "Coverage failure: $total_coverage is below 30%."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- id: "serverless-spark"
|
||||
name: golang:1
|
||||
|
||||
10
.github/labels.yaml
vendored
10
.github/labels.yaml
vendored
@@ -83,10 +83,16 @@
|
||||
- name: 'status: feedback wanted'
|
||||
color: 8befd7
|
||||
description: 'Status: waiting for feedback from community or issue author.'
|
||||
|
||||
- name: 'status: waiting for response'
|
||||
color: 8befd7
|
||||
description: 'Status: reviewer is awaiting feedback or responses from the author before proceeding.'
|
||||
- name: 'status: need-triage'
|
||||
color: 8befd7
|
||||
description: 'Status: Issues that needs to be triaged by the triage automation.'
|
||||
- name: 'status: manual-triage'
|
||||
color: 8befd7
|
||||
description: 'Status: Issues that needs to be triaged by the maintainers.'
|
||||
|
||||
|
||||
- name: 'release candidate'
|
||||
color: 32CD32
|
||||
@@ -179,4 +185,4 @@
|
||||
description: 'Valkey'
|
||||
- name: 'product: yugabytedb'
|
||||
color: 5065c7
|
||||
description: 'YugabyteDB'
|
||||
description: 'YugabyteDB'
|
||||
4
.github/workflows/deploy_versioned_docs.yaml
vendored
4
.github/workflows/deploy_versioned_docs.yaml
vendored
@@ -35,7 +35,9 @@ jobs:
|
||||
ref: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Get Version from Release Tag
|
||||
run: echo "VERSION=${{ github.event.release.tag_name }}" >> $GITHUB_ENV
|
||||
run: echo "VERSION=${GITHUB_EVENT_RELEASE_TAG_NAME}" >> $GITHUB_ENV
|
||||
env:
|
||||
GITHUB_EVENT_RELEASE_TAG_NAME: ${{ github.event.release.tag_name }}
|
||||
|
||||
- name: Setup Hugo
|
||||
uses: peaceiris/actions-hugo@75d2e84710de30f6ff7268e08f310b60ef14033f # v3
|
||||
|
||||
396
.github/workflows/gemini_issue_triage.yaml
vendored
Normal file
396
.github/workflows/gemini_issue_triage.yaml
vendored
Normal file
@@ -0,0 +1,396 @@
|
||||
# Copyright 2026 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
name: '🏷️ Gemini Issue Triage'
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * *' # Runs everyday at midnight
|
||||
issues:
|
||||
types:
|
||||
- 'opened' # automated triage when issue opened
|
||||
workflow_dispatch: # manually dispatch workflow
|
||||
inputs:
|
||||
issue_number:
|
||||
description: 'issue number to triage'
|
||||
required: false # set to false so can manually run bulk scan as well
|
||||
type: 'number'
|
||||
|
||||
concurrency:
|
||||
group: '${{ github.workflow }}-${{ github.event.issue.number || github.event.inputs.issue_number || scheduled }}'
|
||||
cancel-in-progress: true
|
||||
|
||||
defaults:
|
||||
run:
|
||||
shell: 'bash'
|
||||
|
||||
permissions:
|
||||
contents: 'read'
|
||||
id-token: 'write'
|
||||
issues: 'write'
|
||||
statuses: 'write'
|
||||
packages: 'read'
|
||||
actions: 'write' # Required for cancelling a workflow run
|
||||
|
||||
jobs:
|
||||
triage-issue:
|
||||
if: |-
|
||||
github.repository == 'googleapis/genai-toolbox' && !contains(github.event.issue.labels.*.name, 'priority:')
|
||||
timeout-minutes: 10
|
||||
runs-on: 'ubuntu-latest'
|
||||
steps:
|
||||
- name: 'Get issue data for manual trigger'
|
||||
id: 'get_issue_data'
|
||||
if: |-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea'
|
||||
with:
|
||||
github-token: '${{ secrets.GITHUB_TOKEN }}'
|
||||
script: |
|
||||
const { data: issue } = await github.rest.issues.get({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: ${{ github.event.inputs.issue_number }},
|
||||
});
|
||||
core.setOutput('title', issue.title);
|
||||
core.setOutput('body', issue.body);
|
||||
core.setOutput('labels', issue.labels.map(label => label.name).join(','));
|
||||
return issue;
|
||||
|
||||
- name: 'Manual Trigger Pre-flight Checks'
|
||||
if: |-
|
||||
github.event_name == 'workflow_dispatch'
|
||||
env:
|
||||
ISSUE_NUMBER_INPUT: '${{ github.event.inputs.issue_number }}'
|
||||
LABELS: '${{ steps.get_issue_data.outputs.labels }}'
|
||||
run: |
|
||||
if echo "${LABELS}" | grep -q 'priority:'; then
|
||||
echo "Issue #${ISSUE_NUMBER_INPUT} already has 'priority:' labels. Stopping workflow."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Manual triage checks passed."
|
||||
|
||||
- name: 'Checkout'
|
||||
uses: 'actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8' # ratchet:actions/checkout@v5
|
||||
|
||||
- name: 'Get Repository Labels'
|
||||
id: 'get_labels'
|
||||
uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea'
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |-
|
||||
// Fetch ALL labels (handling pagination automatically)
|
||||
const labels = await github.paginate(github.rest.issues.listLabelsForRepo, {
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
per_page: 100
|
||||
});
|
||||
|
||||
// Only grab labels with specific prefix
|
||||
const targetPrefixes = ['priority:', 'product:', 'type:'];
|
||||
const labelNames = labels.map(label => label.name).filter(name =>
|
||||
targetPrefixes.some(prefix => name.startsWith(prefix)));
|
||||
|
||||
// Export labels
|
||||
core.setOutput('available_labels', labelNames.join(','));
|
||||
core.info(`Found ${labelNames.length} labels: ${labelNames.join(', ')}`);
|
||||
return labelNames;
|
||||
|
||||
- name: 'Find untriaged issues'
|
||||
id: 'find_issues'
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
GITHUB_REPOSITORY: '${{ github.repository }}'
|
||||
ISSUE_NUMBER: '${{ github.event.issue.number || github.event.inputs.issue_number }}'
|
||||
run: |-
|
||||
set -euo pipefail
|
||||
|
||||
ISSUES="[]"
|
||||
|
||||
if [[ -n "${ISSUE_NUMBER}" ]]; then
|
||||
echo "🎯 Single Issue Mode: Processing #${ISSUE_NUMBER}..."
|
||||
|
||||
SINGLE_DATA="$(gh issue view "${ISSUE_NUMBER}" \
|
||||
--repo "${GITHUB_REPOSITORY}" \
|
||||
--json number,title,body)"
|
||||
|
||||
ISSUES="[${SINGLE_DATA}]"
|
||||
else
|
||||
echo "📅 Bulk Mode: Running full triage scan..."
|
||||
echo '🔍 Finding issues without labels...'
|
||||
NO_LABEL_ISSUES="$(gh issue list --repo "${GITHUB_REPOSITORY}" \
|
||||
--search 'is:open is:issue no:label' --json number,title,body)"
|
||||
|
||||
echo '🏷️ Finding issues that need triage...'
|
||||
NEED_TRIAGE_ISSUES="$(gh issue list --repo "${GITHUB_REPOSITORY}" \
|
||||
--search "is:open is:issue label:\"status: need-triage\" -label:\"status: manual-triage\"" --limit 1000 --json number,title,body)"
|
||||
|
||||
echo '🔄 Merging and deduplicating issues...'
|
||||
ISSUES="$(echo "${NO_LABEL_ISSUES}" "${NEED_TRIAGE_ISSUES}" | jq -c -s 'add | unique_by(.number)')"
|
||||
fi
|
||||
|
||||
echo '📝 Setting output for GitHub Actions...'
|
||||
echo "issues_to_triage=${ISSUES}" >> "${GITHUB_OUTPUT}"
|
||||
|
||||
ISSUE_COUNT="$(echo "${ISSUES}" | jq 'length')"
|
||||
echo "✅ Found ${ISSUE_COUNT} issues to triage! 🎯"
|
||||
|
||||
- name: 'Run Gemini Issue Analysis'
|
||||
if: |- # skip workflow if its a scheduled workflow without any issues to triage
|
||||
${{ !(github.event_name == 'schedule' &&
|
||||
steps.find_issues.outputs.issues_to_triage == '[]') }}
|
||||
uses: 'google-github-actions/run-gemini-cli@a3bf79042542528e91937b3a3a6fbc4967ee3c31' # ratchet:google-github-actions/run-gemini-cli@v0
|
||||
id: 'gemini_issue_analysis'
|
||||
env:
|
||||
GITHUB_TOKEN: '' # Do not pass any auth token here since this runs on untrusted inputs
|
||||
ISSUES_TO_TRIAGE: '${{ steps.find_issues.outputs.issues_to_triage }}'
|
||||
REPOSITORY: '${{ github.repository }}'
|
||||
AVAILABLE_LABELS: '${{ steps.get_labels.outputs.available_labels }}'
|
||||
with:
|
||||
gcp_workload_identity_provider: '${{ vars.GCP_WIF_PROVIDER }}'
|
||||
gcp_project_id: '${{ vars.GOOGLE_CLOUD_PROJECT }}'
|
||||
gcp_location: '${{ vars.GOOGLE_CLOUD_LOCATION }}'
|
||||
gcp_service_account: '${{ vars.SERVICE_ACCOUNT_EMAIL }}'
|
||||
gemini_api_key: '${{ secrets.GEMINI_API_KEY }}'
|
||||
use_vertex_ai: '${{ vars.GOOGLE_GENAI_USE_VERTEXAI }}'
|
||||
use_gemini_code_assist: '${{ vars.GOOGLE_GENAI_USE_GCA }}'
|
||||
settings: |-
|
||||
{
|
||||
"maxSessionTurns": 25,
|
||||
"telemetry": {
|
||||
"enabled": true,
|
||||
"target": "gcp"
|
||||
}
|
||||
}
|
||||
prompt: |-
|
||||
## Role
|
||||
|
||||
You are an issue triage assistant. Your role is to analyze a GitHub
|
||||
issue and identify appropriate labels based on the definitions
|
||||
provided.
|
||||
|
||||
## Steps
|
||||
1. Check environment variable for issues to triage: $ISSUES_TO_TRIAGE (JSON array of issues).
|
||||
2. Review the available labels: ${{ env.AVAILABLE_LABELS }}.
|
||||
3. Identify the most relevant labels from the existing labels,
|
||||
focusing on 'priority: *', 'type: *', and 'product: *'.
|
||||
4. If the issue already has a 'product: *' label, do not try to
|
||||
change it. If the issue already has a 'type: *' label, do not try to
|
||||
change it. If the issue already has a 'priority: *' label, do not
|
||||
try to change it. For example, if an issue already has a 'product:
|
||||
*' label, you wil only add a 'type: *' and/or 'priority: *' label.
|
||||
Instead, if an issue has no labels, you could add one labels of each
|
||||
kind.
|
||||
5. Fallback Logic:
|
||||
- If you cannot confidently determine the correct 'product: *' label
|
||||
from the definitions, feel free to leave it.
|
||||
- If you cannot confidently determine the correct 'type: *' label
|
||||
from the definitions, feel free to leave it.
|
||||
- If you cannot confidently determine the correct 'priority: *'
|
||||
label from the definitions, apply the 'status: manual-triage'
|
||||
label.
|
||||
6. Give me a single short explanation about why you are selecting
|
||||
each label in the process.
|
||||
7. Output a JSON array of objects, each containing the issue number
|
||||
and the labels to add and remove, along with an explanation and
|
||||
nothing else. Example:
|
||||
```
|
||||
[
|
||||
{
|
||||
"issue_number": 123,
|
||||
"labels_to_add": ["product: alloydb", "priority: p2"],
|
||||
"labels_to_remove": ["status: need-triage"],
|
||||
"explanation": "This issue is a bug within the alloydb tool that needs to be addressed with medium priority."
|
||||
}
|
||||
]
|
||||
```
|
||||
8. If you see that the issue doesn't look like it has sufficient
|
||||
information, leave a comment politely requesting the relevant
|
||||
information.
|
||||
- After identifying appropriate labels to an issue, add "status:
|
||||
need-triage" label to labels_to_remove in the output.
|
||||
10. If you think an issue might be a 'priority: p0' do not apply the
|
||||
'priority: p0' label. Instead, apply a 'status: manual-triage' label
|
||||
and include a note in your explanation.
|
||||
|
||||
|
||||
## Guidelines
|
||||
- Your output must contain exactly one priority: label.
|
||||
- Output only valid JSON format.
|
||||
- Do not include any explanation or additional text, just the JSON.
|
||||
- Only use labels that already exist in the repository.
|
||||
- Do not add comments or modify the issue content.
|
||||
- Triage only the current issue.
|
||||
- Identify only one 'product: *' label
|
||||
- Identify applicable 'priority: *' labels based on the issue content.
|
||||
- Once you categorize the issue if it needs information, bump down the priority by 1 eg.. a p0 would become a p1 a p1 would become a p2. P2 and P3 can stay as is in this scenario.
|
||||
|
||||
Guidelines for Priority labels
|
||||
'priority: p0': Critical / Blocker
|
||||
- Definition: A catastrophic failure that makes the server unusable for most users or poses a severe security risk. This includes installation failures, authentication failures, persistent crashes, or critical security vulnerabilities.
|
||||
- Key Questions:
|
||||
- Is the main goal of the tool (e.g., connecting an agent to a database) completely impossible?
|
||||
- Is the server failing to install or run?
|
||||
- Does this represent a critical security vulnerability?
|
||||
- Does this block existing user and have to be resolved immediately in order to utilize the server again?
|
||||
- Does this issue affect every user immediately upon running the latest version?
|
||||
- Is there absolutely no temporary workaround or alternative method to achieve the desired result?
|
||||
|
||||
'priority: p1': High
|
||||
- Definition: A severe issue that causes a significant degradation of a key feature, produces incorrect or inconsistent results, or severely impacts a large number of users. It requires prompt resolution, though a temporary workaround might exist. This also includes critical missing documentation for core features.
|
||||
- Key Questions:
|
||||
- Does this issue affect a key component that is widely relied upon (e.g., core database operations)?
|
||||
- Are the results produced by the tool incorrect, misleading, or unreliable?
|
||||
- Is a feature failing for a specific, large user group (e.g., all Windows users, all users of a specific shell)?
|
||||
- Does a user need to perform difficult, undocumented steps to work around the problem?
|
||||
- Is essential setup or usage documentation completely missing for a new feature?
|
||||
|
||||
'priority: p2': Medium
|
||||
- Definition: A moderately impactful issue causing inconvenience or a non-optimal experience, but a reasonable workaround exists. This also includes failures in non-core features.
|
||||
- Key Questions:
|
||||
- Is the issue a standard bug fix that only affects a smaller, non-critical area of the code?
|
||||
- Is this a clear, actionable enhancement that adds tangible value without being mission-critical?
|
||||
- Can the user easily and reliably work around the issue without major difficulty?
|
||||
- Is this an overdue technical debt item or a minor documentation correction?
|
||||
|
||||
'priority: p3': Low
|
||||
- Definition: A minor, low-impact issue with minimal effect on functionality. This includes most cosmetic defects, typos in documentation, or unclear help text. They have minimal to no impact on the current functionality or user experience and can be addressed when time and resources allow.
|
||||
- Key Questions:
|
||||
- Is this a typo in the README.md, gemini --help text, or other documentation?
|
||||
- Is this a minor cosmetic issue (e.g., text alignment in output, an extra newline) that doesn't affect usability?
|
||||
- Is the issue a minor cleanup or refactoring that doesn't fix a current problem but improves code style?
|
||||
- Can this be ignored for several release cycles without negatively impacting users?
|
||||
|
||||
Guidelines for Product labels
|
||||
If the issue is specific towards a product, add the product label.
|
||||
For example, alloydb related issue should be assigned the 'product:
|
||||
alloydb' label. The available 'product: *' labels are included in
|
||||
the list of available products.
|
||||
|
||||
Guidelines for Type labels
|
||||
Assign the issue based on type. The available 'type: *' labels are
|
||||
included in the list of available labels.
|
||||
'type: bug'
|
||||
- Error or flaw in code with unintended results or allowing sub-optimal usage patterns.
|
||||
'type: cleanup'
|
||||
- An internal cleanup or hygiene concern.
|
||||
'type: docs'
|
||||
- Improvement to the documentation for an API.
|
||||
'type: feature request'
|
||||
- ‘Nice-to-have’ improvement, new feature or different behavior or design.
|
||||
'type: process'
|
||||
- A process-related concern. May include testing, release, or the like.
|
||||
'type: question'
|
||||
- Request for information or clarification.
|
||||
|
||||
- name: 'Apply Labels to Issue'
|
||||
if: |-
|
||||
${{ steps.gemini_issue_analysis.outcome == 'success' &&
|
||||
steps.gemini_issue_analysis.outputs.summary != '[]' }}
|
||||
env:
|
||||
REPOSITORY: '${{ github.repository }}'
|
||||
LABELS_OUTPUT: '${{ steps.gemini_issue_analysis.outputs.summary }}'
|
||||
uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea'
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |
|
||||
const rawOutput = process.env.LABELS_OUTPUT;
|
||||
core.info(`Raw output from model: ${rawOutput}`);
|
||||
let parsedLabels;
|
||||
|
||||
try {
|
||||
const jsonMatch = rawLabels.match(/```json\s*([\s\S]*?)\s*```/);
|
||||
if (!jsonMatch || !jsonMatch[1]) {
|
||||
throw new Error("Could not find a ```json ... ``` block in the output.");
|
||||
}
|
||||
const jsonString = jsonMatch[1].trim();
|
||||
parsedLabels = JSON.parse(jsonString);
|
||||
core.info(`Parsed labels JSON: ${JSON.stringify(parsedLabels)}`);
|
||||
} catch (err) {
|
||||
core.setFailed(`Failed to parse labels JSON from Gemini output: ${err.message}\nRaw output: ${rawLabels}`);
|
||||
return;
|
||||
}
|
||||
|
||||
for (const entry of parsedLabels) {
|
||||
const issueNumber = entry.issue_number;
|
||||
if (!issueNumber) {
|
||||
core.info(`Skipping entry with no issue number: ${JSON.stringify(entry)}`);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (labelsToAdd.length > 0) {
|
||||
await github.rest.issues.addLabels({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNumber,
|
||||
labels: labelsToAdd
|
||||
});
|
||||
const explanation = entry.explanation ? ` - ${entry.explanation}` : '';
|
||||
core.info(`Successfully added labels for #${issueNumber}: ${labelsToAdd.join(', ')}${explanation}`);
|
||||
}
|
||||
|
||||
if (entry.labels_to_remove && entry.labels_to_remove.length > 0) {
|
||||
for (const label of entry.labels_to_remove) {
|
||||
try {
|
||||
await github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNumber,
|
||||
name: label
|
||||
});
|
||||
} catch (error) {
|
||||
if (error.status !== 404) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
core.info(`Successfully removed labels for #${issueNumber}: ${entry.labels_to_remove.join(', ')}`);
|
||||
}
|
||||
|
||||
if (entry.explanation) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: issueNumber,
|
||||
body: entry.explanation,
|
||||
});
|
||||
}
|
||||
|
||||
if ((!entry.labels_to_add || entry.labels_to_add.length === 0) && (!entry.labels_to_remove || entry.labels_to_remove.length === 0)) {
|
||||
core.info(`No labels to add or remove for #${issueNumber}, leaving as is`);
|
||||
}
|
||||
}
|
||||
|
||||
- name: 'Post Issue Analysis Failure Comment' # only post failure comment for open issues and manual workflow dispatch
|
||||
if: |-
|
||||
${{
|
||||
github.event_name != 'schedule' &&
|
||||
failure() &&
|
||||
steps.gemini_issue_analysis.outcome == 'failure'
|
||||
}}
|
||||
env:
|
||||
ISSUES_TO_TRIAGE: '${{ steps.find_issues.outputs.issues_to_triage }}'
|
||||
ISSUE_NUMBER: '${{ github.event.issue.number || github.event.inputs.issue_number }}'
|
||||
RUN_URL: '${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}'
|
||||
uses: 'actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea'
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
script: |-
|
||||
github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: parseInt(process.env.ISSUE_NUMBER),
|
||||
body: 'There is a problem with the Gemini CLI issue triaging. Please check the [action logs](${process.env.RUN_URL}) for details.'
|
||||
})
|
||||
@@ -167,15 +167,15 @@ tools.
|
||||
[integration.cloudbuild.yaml](.ci/integration.cloudbuild.yaml).
|
||||
|
||||
[tool-get]:
|
||||
https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L31
|
||||
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L41
|
||||
[tool-call]:
|
||||
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L79>
|
||||
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L229
|
||||
[mcp-call]:
|
||||
https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L554
|
||||
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L789
|
||||
[execute-sql]:
|
||||
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L431>
|
||||
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L609
|
||||
[temp-param]:
|
||||
<https://github.com/googleapis/genai-toolbox/blob/fd300dc606d88bf9f7bba689e2cee4e3565537dd/tests/tool.go#L297>
|
||||
https://github.com/googleapis/genai-toolbox/blob/v0.23.0/tests/tool.go#L454
|
||||
[temp-param-doc]:
|
||||
https://googleapis.github.io/genai-toolbox/resources/tools/#template-parameters
|
||||
|
||||
@@ -189,6 +189,10 @@ tools.
|
||||
|
||||
* **(Optional) Add samples** to the `docs/en/samples/<newdb>` directory.
|
||||
|
||||
### Updating labels
|
||||
|
||||
* Add a `product: <source>` label in `.github/labels.yaml`
|
||||
|
||||
### (Optional) Adding Prebuilt Tools
|
||||
|
||||
You can provide developers with a set of "build-time" tools to aid common
|
||||
|
||||
@@ -73,6 +73,7 @@ import (
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselistdatabases"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhouselisttables"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/clickhouse/clickhousesql"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirfetchpage"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatienteverything"
|
||||
_ "github.com/googleapis/genai-toolbox/internal/tools/cloudhealthcare/cloudhealthcarefhirpatientsearch"
|
||||
|
||||
@@ -183,11 +183,11 @@ Protocol (OTLP). If you would like to use a collector, please refer to this
|
||||
|
||||
The following flags are used to determine Toolbox's telemetry configuration:
|
||||
|
||||
| **flag** | **type** | **description** |
|
||||
|----------------------------|----------|------------------------------------------------------------------------------------------------------------------|
|
||||
| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. |
|
||||
| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "<http://127.0.0.1:4318>"). |
|
||||
| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. |
|
||||
| **flag** | **type** | **description** |
|
||||
|----------------------------|----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `--telemetry-gcp` | bool | Enable exporting directly to Google Cloud Monitoring. Default is `false`. |
|
||||
| `--telemetry-otlp` | string | Enable exporting using OpenTelemetry Protocol (OTLP) to the specified endpoint (e.g. "127.0.0.1:4318"). To pass an insecure endpoint here, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`. |
|
||||
| `--telemetry-service-name` | string | Sets the value of the `service.name` resource attribute. Default is `toolbox`. |
|
||||
|
||||
In addition to the flags noted above, you can also make additional configuration
|
||||
for OpenTelemetry via the [General SDK Configuration][sdk-configuration] through
|
||||
@@ -207,5 +207,5 @@ To enable Google Cloud Exporter:
|
||||
To enable OTLP Exporter, provide Collector endpoint:
|
||||
|
||||
```bash
|
||||
./toolbox --telemetry-otlp="http://127.0.0.1:4553"
|
||||
./toolbox --telemetry-otlp="127.0.0.1:4553"
|
||||
```
|
||||
|
||||
@@ -872,11 +872,12 @@
|
||||
}
|
||||
},
|
||||
"node_modules/jws": {
|
||||
"version": "4.0.0",
|
||||
"resolved": "https://registry.npmjs.org/jws/-/jws-4.0.0.tgz",
|
||||
"integrity": "sha512-KDncfTmOZoOMTFG4mBlG0qUIOlc03fmzH+ru6RgYVZhPkyiy/92Owlt/8UEN+a4TXR1FQetfIpJE8ApdvdVxTg==",
|
||||
"version": "4.0.1",
|
||||
"resolved": "https://registry.npmjs.org/jws/-/jws-4.0.1.tgz",
|
||||
"integrity": "sha512-EKI/M/yqPncGUUh44xz0PxSidXFr/+r0pA70+gIYhjv+et7yxM+s29Y+VGDkovRofQem0fs7Uvf4+YmAdyRduA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"jwa": "^2.0.0",
|
||||
"jwa": "^2.0.1",
|
||||
"safe-buffer": "^5.0.1"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -79,12 +79,16 @@ There are a couple of steps to run and use a Collector.
|
||||
```
|
||||
|
||||
1. Run toolbox with the `--telemetry-otlp` flag. Configure it to send them to
|
||||
`http://127.0.0.1:4553` (for HTTP) or the Collector's URL.
|
||||
`127.0.0.1:4553` (for HTTP) or the Collector's URL.
|
||||
|
||||
```bash
|
||||
./toolbox --telemetry-otlp=http://127.0.0.1:4553
|
||||
./toolbox --telemetry-otlp=127.0.0.1:4553
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
To pass an insecure endpoint, set environment variable `OTEL_EXPORTER_OTLP_INSECURE=true`.
|
||||
{{< /notice >}}
|
||||
|
||||
1. Once telemetry datas are collected, you can view them in your telemetry
|
||||
backend. If you are using GCP exporters, telemetry will be visible in GCP
|
||||
dashboard at [Metrics Explorer][metrics-explorer] and [Trace
|
||||
|
||||
40
docs/en/resources/sources/cloud-gda.md
Normal file
40
docs/en/resources/sources/cloud-gda.md
Normal file
@@ -0,0 +1,40 @@
|
||||
---
|
||||
title: "Gemini Data Analytics"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A "cloud-gemini-data-analytics" source provides a client for the Gemini Data Analytics API.
|
||||
aliases:
|
||||
- /resources/sources/cloud-gemini-data-analytics
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-gemini-data-analytics` source provides a client to interact with the [Gemini Data Analytics API](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/reference/rest). This allows tools to send natural language queries to the API.
|
||||
|
||||
Authentication can be handled in two ways:
|
||||
|
||||
1. **Application Default Credentials (ADC) (Recommended):** By default, the source uses ADC to authenticate with the API. The Toolbox server will fetch the credentials from its running environment (server-side authentication). This is the recommended method.
|
||||
2. **Client-side OAuth:** If `useClientOAuth` is set to `true`, the source expects the authentication token to be provided by the caller when making a request to the Toolbox server (typically via an HTTP Bearer token). The Toolbox server will then forward this token to the underlying Gemini Data Analytics API calls.
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-gda-source:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: my-project-id
|
||||
|
||||
my-oauth-gda-source:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: my-project-id
|
||||
useClientOAuth: true
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| -------------- | :------: | :----------: | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| kind | string | true | Must be "cloud-gemini-data-analytics". |
|
||||
| projectId | string | true | The Google Cloud Project ID where the API is enabled. |
|
||||
| useClientOAuth | boolean | false | If true, the source uses the token provided by the caller (forwarded to the API). Otherwise, it uses server-side Application Default Credentials (ADC). Defaults to `false`. |
|
||||
@@ -18,10 +18,10 @@ DW) database workloads.
|
||||
## Available Tools
|
||||
|
||||
- [`oracle-sql`](../tools/oracle/oracle-sql.md)
|
||||
Execute pre-defined prepared SQL queries in Oracle.
|
||||
Execute pre-defined prepared SQL queries in Oracle.
|
||||
|
||||
- [`oracle-execute-sql`](../tools/oracle/oracle-execute-sql.md)
|
||||
Run parameterized SQL queries in Oracle.
|
||||
Run parameterized SQL queries in Oracle.
|
||||
|
||||
## Requirements
|
||||
|
||||
@@ -33,6 +33,25 @@ user][oracle-users] to log in to the database with the necessary permissions.
|
||||
[oracle-users]:
|
||||
https://docs.oracle.com/en/database/oracle/oracle-database/21/sqlrf/CREATE-USER.html
|
||||
|
||||
### Oracle Driver Requirement (Conditional)
|
||||
|
||||
The Oracle source offers two connection drivers:
|
||||
|
||||
1. **Pure Go Driver (`useOCI: false`, default):** Uses the `go-ora` library.
|
||||
This driver is simpler and does not require any local Oracle software
|
||||
installation, but it **lacks support for advanced features** like Oracle
|
||||
Wallets or Kerberos authentication.
|
||||
|
||||
2. **OCI-Based Driver (`useOCI: true`):** Uses the `godror` library, which
|
||||
provides access to **advanced Oracle features** like Digital Wallet support.
|
||||
|
||||
If you set `useOCI: true`, you **must** install the **Oracle Instant Client**
|
||||
libraries on the machine where this tool runs.
|
||||
|
||||
You can download the Instant Client from the official Oracle website: [Oracle
|
||||
Instant Client
|
||||
Downloads](https://www.oracle.com/database/technologies/instant-client/downloads.html)
|
||||
|
||||
## Connection Methods
|
||||
|
||||
You can configure the connection to your Oracle database using one of the
|
||||
@@ -66,12 +85,15 @@ using a TNS (Transparent Network Substrate) alias.
|
||||
containing it. This setting will override the `TNS_ADMIN` environment
|
||||
variable.
|
||||
|
||||
## Example
|
||||
## Examples
|
||||
|
||||
This example demonstrates the four connection methods you could choose from:
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
my-oracle-source:
|
||||
kind: oracle
|
||||
|
||||
# --- Choose one connection method ---
|
||||
# 1. Host, Port, and Service Name
|
||||
host: 127.0.0.1
|
||||
@@ -88,6 +110,43 @@ sources:
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
|
||||
# Optional: Set to true to use the OCI-based driver for advanced features (Requires Oracle Instant Client)
|
||||
```
|
||||
|
||||
### Using an Oracle Wallet
|
||||
|
||||
Oracle Wallet allows you to store credentails used for database connection. Depending whether you are using an OCI-based driver, the wallet configuration is different.
|
||||
|
||||
#### Pure Go Driver (`useOCI: false`) - Oracle Wallet
|
||||
|
||||
The `go-ora` driver uses the `walletLocation` field to connect to a database secured with an Oracle Wallet without standard username and password.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
pure-go-wallet:
|
||||
kind: oracle
|
||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
# The TNS Alias is often required to connect to a service registered in tnsnames.ora
|
||||
tnsAlias: "SECURE_DB_ALIAS"
|
||||
walletLocation: "/path/to/my/wallet/directory"
|
||||
```
|
||||
|
||||
#### OCI-Based Driver (`useOCI: true`) - Oracle Wallet
|
||||
|
||||
For the OCI-based driver, wallet authentication is triggered by setting tnsAdmin to the wallet directory and connecting via a tnsAlias.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
oci-wallet:
|
||||
kind: oracle
|
||||
connectionString: "127.0.0.1:1521/XEPDB1"
|
||||
user: ${USER_NAME}
|
||||
password: ${PASSWORD}
|
||||
tnsAlias: "WALLET_DB_ALIAS"
|
||||
tnsAdmin: "/opt/oracle/wallet" # Directory containing tnsnames.ora, sqlnet.ora, and wallet files
|
||||
useOCI: true
|
||||
```
|
||||
|
||||
{{< notice tip >}}
|
||||
@@ -97,14 +156,15 @@ instead of hardcoding your secrets into the configuration file.
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "oracle". |
|
||||
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
|
||||
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
|
||||
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
|
||||
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
|
||||
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
||||
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
||||
| **field** | **type** | **required** | **description** |
|
||||
|------------------|:--------:|:------------:|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| kind | string | true | Must be "oracle". |
|
||||
| user | string | true | Name of the Oracle user to connect as (e.g. "my-oracle-user"). |
|
||||
| password | string | true | Password of the Oracle user (e.g. "my-password"). |
|
||||
| host | string | false | IP address or hostname to connect to (e.g. "127.0.0.1"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| port | integer | false | Port to connect to (e.g. "1521"). Required if not using `connectionString` or `tnsAlias`. |
|
||||
| serviceName | string | false | The Oracle service name of the database to connect to. Required if not using `connectionString` or `tnsAlias`. |
|
||||
| connectionString | string | false | A direct connection string (e.g. "hostname:port/servicename"). Use as an alternative to `host`, `port`, and `serviceName`. |
|
||||
| tnsAlias | string | false | A TNS alias from a `tnsnames.ora` file. Use as an alternative to `host`/`port` or `connectionString`. |
|
||||
| tnsAdmin | string | false | Path to the directory containing the `tnsnames.ora` file. This overrides the `TNS_ADMIN` environment variable if it is set. |
|
||||
| useOCI | bool | false | If true, uses the OCI-based driver (godror) which supports Oracle Wallet/Kerberos but requires the Oracle Instant Client libraries to be installed. Defaults to false (pure Go driver). |
|
||||
|
||||
7
docs/en/resources/tools/cloudgda/_index.md
Normal file
7
docs/en/resources/tools/cloudgda/_index.md
Normal file
@@ -0,0 +1,7 @@
|
||||
---
|
||||
title: "Gemini Data Analytics"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
Tools for Gemini Data Analytics.
|
||||
---
|
||||
92
docs/en/resources/tools/cloudgda/cloud-gda-query.md
Normal file
92
docs/en/resources/tools/cloudgda/cloud-gda-query.md
Normal file
@@ -0,0 +1,92 @@
|
||||
---
|
||||
title: "Gemini Data Analytics QueryData"
|
||||
type: docs
|
||||
weight: 1
|
||||
description: >
|
||||
A tool to convert natural language queries into SQL statements using the Gemini Data Analytics QueryData API.
|
||||
aliases:
|
||||
- /resources/tools/cloud-gemini-data-analytics-query
|
||||
---
|
||||
|
||||
## About
|
||||
|
||||
The `cloud-gemini-data-analytics-query` tool allows you to send natural language questions to the Gemini Data Analytics API and receive structured responses containing SQL queries, natural language answers, and explanations. For details on defining data agent context for database data sources, see the official [documentation](https://docs.cloud.google.com/gemini/docs/conversational-analytics-api/data-agent-authored-context-databases).
|
||||
|
||||
## Example
|
||||
|
||||
```yaml
|
||||
tools:
|
||||
my-gda-query-tool:
|
||||
kind: cloud-gemini-data-analytics-query
|
||||
source: my-gda-source
|
||||
description: "Use this tool to send natural language queries to the Gemini Data Analytics API and receive SQL, natural language answers, and explanations."
|
||||
location: ${your_database_location}
|
||||
context:
|
||||
datasourceReferences:
|
||||
cloudSqlReference:
|
||||
databaseReference:
|
||||
projectId: "${your_project_id}"
|
||||
region: "${your_database_instance_region}"
|
||||
instanceId: "${your_database_instance_id}"
|
||||
databaseId: "${your_database_name}"
|
||||
engine: "POSTGRESQL"
|
||||
agentContextReference:
|
||||
contextSetId: "${your_context_set_id}" # E.g. projects/${project_id}/locations/${context_set_location}/contextSets/${context_set_id}
|
||||
generationOptions:
|
||||
generateQueryResult: true
|
||||
generateNaturalLanguageAnswer: true
|
||||
generateExplanation: true
|
||||
generateDisambiguationQuestion: true
|
||||
```
|
||||
|
||||
### Usage Flow
|
||||
|
||||
When using this tool, a `prompt` parameter containing a natural language query is provided to the tool (typically by an agent). The tool then interacts with the Gemini Data Analytics API using the context defined in your configuration.
|
||||
|
||||
The structure of the response depends on the `generationOptions` configured in your tool definition (e.g., enabling `generateQueryResult` will include the SQL query results).
|
||||
|
||||
See [Data Analytics API REST documentation](https://clouddocs.devsite.corp.google.com/gemini/docs/conversational-analytics-api/reference/rest/v1alpha/projects.locations/queryData?rep_location=global) for details.
|
||||
|
||||
**Example Input Prompt:**
|
||||
|
||||
```text
|
||||
How many accounts who have region in Prague are eligible for loans? A3 contains the data of region.
|
||||
```
|
||||
|
||||
**Example API Response:**
|
||||
|
||||
```json
|
||||
{
|
||||
"generatedQuery": "SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = 'Prague'",
|
||||
"intentExplanation": "I found a template that matches the user's question. The template asks about the number of accounts who have region in a given city and are eligible for loans. The question asks about the number of accounts who have region in Prague and are eligible for loans. The template's parameterized SQL is 'SELECT COUNT(T1.account_id) FROM account AS T1 INNER JOIN loan AS T2 ON T1.account_id = T2.account_id INNER JOIN district AS T3 ON T1.district_id = T3.district_id WHERE T3.A3 = ?'. I will replace the named parameter '?' with 'Prague'.",
|
||||
"naturalLanguageAnswer": "There are 84 accounts from the Prague region that are eligible for loans.",
|
||||
"queryResult": {
|
||||
"columns": [
|
||||
{
|
||||
"type": "INT64"
|
||||
}
|
||||
],
|
||||
"rows": [
|
||||
{
|
||||
"values": [
|
||||
{
|
||||
"value": "84"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"totalRowCount": "1"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Reference
|
||||
|
||||
| **field** | **type** | **required** | **description** |
|
||||
| ----------------- | :------: | :----------: | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| kind | string | true | Must be "cloud-gemini-data-analytics-query". |
|
||||
| source | string | true | The name of the `cloud-gemini-data-analytics` source to use. |
|
||||
| description | string | true | A description of the tool's purpose. |
|
||||
| location | string | true | The Google Cloud location of the target database resource (e.g., "us-central1"). This is used to construct the parent resource name in the API call. |
|
||||
| context | object | true | The context for the query, including datasource references. See [QueryDataContext](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L156) for details. |
|
||||
| generationOptions | object | false | Options for generating the response. See [GenerationOptions](https://github.com/googleapis/googleapis/blob/b32495a713a68dd0dff90cf0b24021debfca048a/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto#L135) for details. |
|
||||
8
go.mod
8
go.mod
@@ -12,7 +12,7 @@ require (
|
||||
cloud.google.com/go/dataplex v1.28.0
|
||||
cloud.google.com/go/dataproc/v2 v2.15.0
|
||||
cloud.google.com/go/firestore v1.20.0
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0
|
||||
cloud.google.com/go/longrunning v0.7.0
|
||||
cloud.google.com/go/spanner v1.86.1
|
||||
github.com/ClickHouse/clickhouse-go/v2 v2.40.3
|
||||
@@ -33,6 +33,7 @@ require (
|
||||
github.com/go-playground/validator/v10 v10.28.0
|
||||
github.com/go-sql-driver/mysql v1.9.3
|
||||
github.com/goccy/go-yaml v1.18.0
|
||||
github.com/godror/godror v0.49.4
|
||||
github.com/google/go-cmp v0.7.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.7.6
|
||||
@@ -91,6 +92,7 @@ require (
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.29.0 // indirect
|
||||
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.54.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
|
||||
github.com/ajg/form v1.5.1 // indirect
|
||||
github.com/apache/arrow/go/v15 v15.0.2 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
@@ -107,11 +109,13 @@ require (
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.10 // indirect
|
||||
github.com/go-jose/go-jose/v4 v4.1.2 // indirect
|
||||
github.com/go-logfmt/logfmt v0.6.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/godror/knownpb v0.3.0 // indirect
|
||||
github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect
|
||||
github.com/golang-sql/sqlexp v0.1.0 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
|
||||
@@ -181,7 +185,7 @@ require (
|
||||
golang.org/x/time v0.14.0 // indirect
|
||||
golang.org/x/tools v0.38.0 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f // indirect
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||
google.golang.org/grpc v1.76.0 // indirect
|
||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
||||
|
||||
22
go.sum
22
go.sum
@@ -311,8 +311,8 @@ cloud.google.com/go/gaming v1.6.0/go.mod h1:YMU1GEvA39Qt3zWGyAVA9bpYz/yAhTvaQ1t2
|
||||
cloud.google.com/go/gaming v1.7.0/go.mod h1:LrB8U7MHdGgFG851iHAfqUdLcKBdQ55hzXy9xBJz0+w=
|
||||
cloud.google.com/go/gaming v1.8.0/go.mod h1:xAqjS8b7jAVW0KFYeRUxngo9My3f33kFmua++Pi+ggM=
|
||||
cloud.google.com/go/gaming v1.9.0/go.mod h1:Fc7kEmCObylSWLO334NcO+O9QMDyz+TKC4v1D7X+Bc0=
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1 h1:gtG/9VlUJpL67yukFen/twkAEHliYvW7610Rlnn5rpQ=
|
||||
cloud.google.com/go/geminidataanalytics v0.2.1/go.mod h1:gIsj/ELDCzVbw24185zwjXgbzYiqdGe7TSSK2HrdtA0=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0 h1:2Wi/kqFb5OLuEGH7q+/miE19VTqK1MYHjBEHENap9HI=
|
||||
cloud.google.com/go/geminidataanalytics v0.3.0/go.mod h1:QRc0b6ywyc3Z7S3etFgslz7hippkW/jRvtops5rKqIg=
|
||||
cloud.google.com/go/gkebackup v0.2.0/go.mod h1:XKvv/4LfG829/B8B7xRkk8zRrOEbKtEam6yNfuQNH60=
|
||||
cloud.google.com/go/gkebackup v0.3.0/go.mod h1:n/E671i1aOQvUxT541aTkCwExO/bTer2HDlj4TsBRAo=
|
||||
cloud.google.com/go/gkebackup v0.4.0/go.mod h1:byAyBGUwYGEEww7xsbnUTBHIYcOPy/PgUWUtOeRm9Vg=
|
||||
@@ -683,6 +683,10 @@ github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8
|
||||
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
|
||||
github.com/PuerkitoBio/goquery v1.10.3 h1:pFYcNSqHxBD06Fpj/KsbStFRsgRATgnf3LeXiUkhzPo=
|
||||
github.com/PuerkitoBio/goquery v1.10.3/go.mod h1:tMUX0zDMHXYlAQk6p35XxQMqMweEKB7iK7iLNd4RH4Y=
|
||||
github.com/UNO-SOFT/zlog v0.8.1 h1:TEFkGJHtUfTRgMkLZiAjLSHALjwSBdw6/zByMC5GJt4=
|
||||
github.com/UNO-SOFT/zlog v0.8.1/go.mod h1:yqFOjn3OhvJ4j7ArJqQNA+9V+u6t9zSAyIZdWdMweWc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4 h1:r8cNvo8o6sR4QShBXQd1bKw/VVLSQma/V2KhTBPf+Sc=
|
||||
github.com/VictoriaMetrics/easyproto v0.1.4/go.mod h1:QlGlzaJnDfFd8Lk6Ci/fuLxfTo3/GThPs2KH23mv710=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26 h1:3YVZUqkoev4mL+aCwVOSWV4M7pN+NURHL38Z2zq5JKA=
|
||||
github.com/ahmetb/dlog v0.0.0-20170105205344-4fb5f8204f26/go.mod h1:ymXt5bw5uSNu4jveerFxE0vNYxF8ncqbptntMaFMg3k=
|
||||
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
|
||||
@@ -884,6 +888,8 @@ github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vb
|
||||
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
|
||||
github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk=
|
||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-logfmt/logfmt v0.6.0 h1:wGYYu3uicYdqXVgoYbvnkrPVXkuLM1p1ifugDMEdRi4=
|
||||
github.com/go-logfmt/logfmt v0.6.0/go.mod h1:WYhtIu8zTZfxdn5+rREduYbwxfcBr/Vr6KEVveWlfTs=
|
||||
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
|
||||
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
|
||||
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
@@ -909,6 +915,10 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/godror/godror v0.49.4 h1:8kKWKoR17nPX7u10hr4GwD4u10hzTZED9ihdkuzRrKI=
|
||||
github.com/godror/godror v0.49.4/go.mod h1:kTMcxZzRw73RT5kn9v3JkBK4kHI6dqowHotqV72ebU8=
|
||||
github.com/godror/knownpb v0.3.0 h1:+caUdy8hTtl7X05aPl3tdL540TvCcaQA6woZQroLZMw=
|
||||
github.com/godror/knownpb v0.3.0/go.mod h1:PpTyfJwiOEAzQl7NtVCM8kdPCnp3uhxsZYIzZ5PV4zU=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
@@ -1172,6 +1182,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh
|
||||
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4 h1:7toxehVcYkZbyxV4W3Ib9VcnyRBQPucF+VwNNmtSXi4=
|
||||
github.com/neo4j/neo4j-go-driver/v5 v5.28.4/go.mod h1:Vff8OwT7QpLm7L2yYr85XNWe9Rbqlbeb9asNXJTHO4k=
|
||||
github.com/oklog/ulid/v2 v2.0.2 h1:r4fFzBm+bv0wNKNh5eXTwU7i85y5x+uwkxCUTNVQqLc=
|
||||
github.com/oklog/ulid/v2 v2.0.2/go.mod h1:mtBL0Qe/0HAx6/a4Z30qxVIAL1eQDweXq5lxOEiwQ68=
|
||||
github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8=
|
||||
github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
@@ -1671,6 +1683,8 @@ golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
@@ -1990,8 +2004,8 @@ google.golang.org/genproto v0.0.0-20230331144136-dcfb400f0633/go.mod h1:UUQDJDOl
|
||||
google.golang.org/genproto v0.0.0-20230410155749-daa745c078e1/go.mod h1:nKE/iIaLqn2bQwXBg8f1g2Ylh6r5MN5CmZvuzZCgsCU=
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8 h1:a12a2/BiVRxRWIqBbfqoSK6tgq8cyUgMnEI81QlPge0=
|
||||
google.golang.org/genproto v0.0.0-20251022142026-3a174f9686a8/go.mod h1:1Ic78BnpzY8OaTCmzxJDP4qC9INZPbGZl+54RKjtyeI=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f h1:OiFuztEyBivVKDvguQJYWq1yDcfAHIID/FVrPR4oiI0=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251014184007-4626949a642f/go.mod h1:kprOiu9Tr0JYyD6DORrc4Hfyk3RFXqkQ3ctHEum3ZbM=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4=
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba/go.mod h1:G5IanEx8/PgI9w6CFcYQf7jMtHQhZruvfM1i3qOqk5U=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
|
||||
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
|
||||
|
||||
@@ -205,10 +205,13 @@ func (s *stdioSession) readLine(ctx context.Context) (string, error) {
|
||||
}
|
||||
|
||||
// write writes to stdout with response to client
|
||||
func (s *stdioSession) write(ctx context.Context, response any) error {
|
||||
res, _ := json.Marshal(response)
|
||||
func (s *stdioSession) write(_ context.Context, response any) error {
|
||||
res, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal response to JSON: %w", err)
|
||||
}
|
||||
|
||||
_, err := fmt.Fprintf(s.writer, "%s\n", res)
|
||||
_, err = fmt.Fprintf(s.writer, "%s\n", res)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
154
internal/sources/cloudgda/cloud_gda.go
Normal file
154
internal/sources/cloudgda/cloud_gda.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const SourceKind string = "cloud-gemini-data-analytics"
|
||||
const Endpoint string = "https://geminidataanalytics.googleapis.com"
|
||||
|
||||
type userAgentRoundTripper struct {
|
||||
userAgent string
|
||||
next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *userAgentRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
ua := newReq.Header.Get("User-Agent")
|
||||
if ua == "" {
|
||||
newReq.Header.Set("User-Agent", rt.userAgent)
|
||||
} else {
|
||||
newReq.Header.Set("User-Agent", ua+" "+rt.userAgent)
|
||||
}
|
||||
return rt.next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ sources.SourceConfig = Config{}
|
||||
|
||||
func init() {
|
||||
if !sources.Register(SourceKind, newConfig) {
|
||||
panic(fmt.Sprintf("source kind %q already registered", SourceKind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources.SourceConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ProjectID string `yaml:"projectId" validate:"required"`
|
||||
UseClientOAuth bool `yaml:"useClientOAuth"`
|
||||
}
|
||||
|
||||
func (r Config) SourceConfigKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
// Initialize initializes a Gemini Data Analytics Source instance.
|
||||
func (r Config) Initialize(ctx context.Context, tracer trace.Tracer) (sources.Source, error) {
|
||||
ua, err := util.UserAgentFromContext(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error in User Agent retrieval: %s", err)
|
||||
}
|
||||
|
||||
var client *http.Client
|
||||
if r.UseClientOAuth {
|
||||
client = &http.Client{
|
||||
Transport: &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: http.DefaultTransport,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Use Application Default Credentials
|
||||
// Scope: "https://www.googleapis.com/auth/cloud-platform" is generally sufficient for GDA
|
||||
creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find default credentials: %w", err)
|
||||
}
|
||||
baseClient := oauth2.NewClient(ctx, creds.TokenSource)
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: ua,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
client = baseClient
|
||||
}
|
||||
|
||||
s := &Source{
|
||||
Config: r,
|
||||
Client: client,
|
||||
BaseURL: Endpoint,
|
||||
userAgent: ua,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
var _ sources.Source = &Source{}
|
||||
|
||||
type Source struct {
|
||||
Config
|
||||
Client *http.Client
|
||||
BaseURL string
|
||||
userAgent string
|
||||
}
|
||||
|
||||
func (s *Source) SourceKind() string {
|
||||
return SourceKind
|
||||
}
|
||||
|
||||
func (s *Source) ToConfig() sources.SourceConfig {
|
||||
return s.Config
|
||||
}
|
||||
|
||||
func (s *Source) GetClient(ctx context.Context, accessToken string) (*http.Client, error) {
|
||||
if s.UseClientOAuth {
|
||||
if accessToken == "" {
|
||||
return nil, fmt.Errorf("client-side OAuth is enabled but no access token was provided")
|
||||
}
|
||||
token := &oauth2.Token{AccessToken: accessToken}
|
||||
baseClient := oauth2.NewClient(ctx, oauth2.StaticTokenSource(token))
|
||||
baseClient.Transport = &userAgentRoundTripper{
|
||||
userAgent: s.userAgent,
|
||||
next: baseClient.Transport,
|
||||
}
|
||||
return baseClient, nil
|
||||
}
|
||||
return s.Client, nil
|
||||
}
|
||||
|
||||
func (s *Source) UseClientAuthorization() bool {
|
||||
return s.UseClientOAuth
|
||||
}
|
||||
213
internal/sources/cloudgda/cloud_gda_test.go
Normal file
213
internal/sources/cloudgda/cloud_gda_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
)
|
||||
|
||||
func TestParseFromYamlCloudGDA(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: test-project-id
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "test-project-id",
|
||||
UseClientOAuth: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "use client auth example",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
projectId: another-project
|
||||
useClientOAuth: true
|
||||
`,
|
||||
want: map[string]sources.SourceConfig{
|
||||
"my-gda-instance": cloudgda.Config{
|
||||
Name: "my-gda-instance",
|
||||
Kind: cloudgda.SourceKind,
|
||||
ProjectID: "another-project",
|
||||
UseClientOAuth: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Sources)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "missing projectId",
|
||||
in: `
|
||||
sources:
|
||||
my-gda-instance:
|
||||
kind: cloud-gemini-data-analytics
|
||||
`,
|
||||
err: "unable to parse source \"my-gda-instance\" as \"cloud-gemini-data-analytics\": Key: 'Config.ProjectID' Error:Field validation for 'ProjectID' failed on the 'required' tag",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := err.Error()
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error: got %q, want %q", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
// Create a dummy credentials file for testing ADC
|
||||
credFile := filepath.Join(t.TempDir(), "application_default_credentials.json")
|
||||
dummyCreds := `{
|
||||
"client_id": "foo",
|
||||
"client_secret": "bar",
|
||||
"refresh_token": "baz",
|
||||
"type": "authorized_user"
|
||||
}`
|
||||
if err := os.WriteFile(credFile, []byte(dummyCreds), 0644); err != nil {
|
||||
t.Fatalf("failed to write dummy credentials file: %v", err)
|
||||
}
|
||||
t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", credFile)
|
||||
|
||||
// Use ContextWithUserAgent to avoid "unable to retrieve user agent" error
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
tracer := noop.NewTracerProvider().Tracer("test")
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgda.Config
|
||||
wantClientOAuth bool
|
||||
}{
|
||||
{
|
||||
desc: "initialize with ADC",
|
||||
cfg: cloudgda.Config{Name: "test-gda", Kind: cloudgda.SourceKind, ProjectID: "test-proj"},
|
||||
wantClientOAuth: false,
|
||||
},
|
||||
{
|
||||
desc: "initialize with client OAuth",
|
||||
cfg: cloudgda.Config{Name: "test-gda-oauth", Kind: cloudgda.SourceKind, ProjectID: "test-proj", UseClientOAuth: true},
|
||||
wantClientOAuth: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
src, err := tc.cfg.Initialize(ctx, tracer)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize source: %v", err)
|
||||
}
|
||||
|
||||
gdaSrc, ok := src.(*cloudgda.Source)
|
||||
if !ok {
|
||||
t.Fatalf("expected *cloudgda.Source, got %T", src)
|
||||
}
|
||||
|
||||
// Check that the client is non-nil
|
||||
if gdaSrc.Client == nil && !tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for ADC, got nil")
|
||||
}
|
||||
// When client OAuth is true, the source's client should be initialized with a base HTTP client
|
||||
// that includes the user agent round tripper, but not the OAuth token. The token-aware
|
||||
// client is created by GetClient.
|
||||
if gdaSrc.Client == nil && tc.wantClientOAuth {
|
||||
t.Fatal("expected non-nil HTTP client for client OAuth config, got nil")
|
||||
}
|
||||
|
||||
// Test UseClientAuthorization method
|
||||
if gdaSrc.UseClientAuthorization() != tc.wantClientOAuth {
|
||||
t.Errorf("UseClientAuthorization mismatch: want %t, got %t", tc.wantClientOAuth, gdaSrc.UseClientAuthorization())
|
||||
}
|
||||
|
||||
// Test GetClient with accessToken for client OAuth scenarios
|
||||
if tc.wantClientOAuth {
|
||||
client, err := gdaSrc.GetClient(ctx, "dummy-token")
|
||||
if err != nil {
|
||||
t.Fatalf("GetClient with token failed: %v", err)
|
||||
}
|
||||
if client == nil {
|
||||
t.Fatal("expected non-nil HTTP client from GetClient with token, got nil")
|
||||
}
|
||||
// Ensure passing empty token with UseClientOAuth enabled returns error
|
||||
_, err = gdaSrc.GetClient(ctx, "")
|
||||
if err == nil || err.Error() != "client-side OAuth is enabled but no access token was provided" {
|
||||
t.Errorf("expected 'client-side OAuth is enabled but no access token was provided' error, got: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,9 +9,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
_ "github.com/godror/godror" // OCI driver
|
||||
_ "github.com/sijms/go-ora/v2" // Pure Go driver
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
_ "github.com/sijms/go-ora/v2"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -32,7 +34,7 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate that we have one of: tns_alias, connection_string, or host+service_name
|
||||
// Validate that we have one of: tnsAlias, connectionString, or host+service_name
|
||||
if err := actual.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid Oracle configuration: %w", err)
|
||||
}
|
||||
@@ -43,21 +45,24 @@ func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (sources
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
ConnectionString string `yaml:"connectionString,omitempty"` // Direct connection string (hostname[:port]/servicename)
|
||||
TnsAlias string `yaml:"tnsAlias,omitempty"` // TNS alias from tnsnames.ora
|
||||
Host string `yaml:"host,omitempty"` // Optional when using connectionString/tnsAlias
|
||||
Port int `yaml:"port,omitempty"` // Explicit port support
|
||||
ServiceName string `yaml:"serviceName,omitempty"` // Optional when using connectionString/tnsAlias
|
||||
ConnectionString string `yaml:"connectionString,omitempty"`
|
||||
TnsAlias string `yaml:"tnsAlias,omitempty"`
|
||||
TnsAdmin string `yaml:"tnsAdmin,omitempty"`
|
||||
Host string `yaml:"host,omitempty"`
|
||||
Port int `yaml:"port,omitempty"`
|
||||
ServiceName string `yaml:"serviceName,omitempty"`
|
||||
User string `yaml:"user" validate:"required"`
|
||||
Password string `yaml:"password" validate:"required"`
|
||||
TnsAdmin string `yaml:"tnsAdmin,omitempty"` // Optional: override TNS_ADMIN environment variable
|
||||
UseOCI bool `yaml:"useOCI,omitempty"`
|
||||
WalletLocation string `yaml:"walletLocation,omitempty"`
|
||||
}
|
||||
|
||||
// validate ensures we have one of: tns_alias, connection_string, or host+service_name
|
||||
func (c Config) validate() error {
|
||||
hasTnsAdmin := strings.TrimSpace(c.TnsAdmin) != ""
|
||||
hasTnsAlias := strings.TrimSpace(c.TnsAlias) != ""
|
||||
hasConnStr := strings.TrimSpace(c.ConnectionString) != ""
|
||||
hasHostService := strings.TrimSpace(c.Host) != "" && strings.TrimSpace(c.ServiceName) != ""
|
||||
hasWallet := strings.TrimSpace(c.WalletLocation) != ""
|
||||
|
||||
connectionMethods := 0
|
||||
if hasTnsAlias {
|
||||
@@ -78,6 +83,14 @@ func (c Config) validate() error {
|
||||
return fmt.Errorf("provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'")
|
||||
}
|
||||
|
||||
if hasTnsAdmin && !c.UseOCI {
|
||||
return fmt.Errorf("`tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead")
|
||||
}
|
||||
|
||||
if hasWallet && c.UseOCI {
|
||||
return fmt.Errorf("when using an OCI driver, use `tnsAdmin` to specify credentials file location instead")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -132,7 +145,8 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set TNS_ADMIN environment variable if specified in config.
|
||||
hasWallet := strings.TrimSpace(config.WalletLocation) != ""
|
||||
|
||||
if config.TnsAdmin != "" {
|
||||
originalTnsAdmin := os.Getenv("TNS_ADMIN")
|
||||
os.Setenv("TNS_ADMIN", config.TnsAdmin)
|
||||
@@ -147,28 +161,49 @@ func initOracleConnection(ctx context.Context, tracer trace.Tracer, config Confi
|
||||
}()
|
||||
}
|
||||
|
||||
var serverString string
|
||||
var connectStringBase string
|
||||
if config.TnsAlias != "" {
|
||||
// Use TNS alias
|
||||
serverString = strings.TrimSpace(config.TnsAlias)
|
||||
connectStringBase = strings.TrimSpace(config.TnsAlias)
|
||||
} else if config.ConnectionString != "" {
|
||||
// Use provided connection string directly (hostname[:port]/servicename format)
|
||||
serverString = strings.TrimSpace(config.ConnectionString)
|
||||
connectStringBase = strings.TrimSpace(config.ConnectionString)
|
||||
} else {
|
||||
// Build connection string from host and service_name
|
||||
if config.Port > 0 {
|
||||
serverString = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
|
||||
connectStringBase = fmt.Sprintf("%s:%d/%s", config.Host, config.Port, config.ServiceName)
|
||||
} else {
|
||||
serverString = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
|
||||
connectStringBase = fmt.Sprintf("%s/%s", config.Host, config.ServiceName)
|
||||
}
|
||||
}
|
||||
|
||||
connStr := fmt.Sprintf("oracle://%s:%s@%s",
|
||||
config.User, config.Password, serverString)
|
||||
var driverName string
|
||||
var finalConnStr string
|
||||
|
||||
db, err := sql.Open("oracle", connStr)
|
||||
if config.UseOCI {
|
||||
// Use godror driver (requires OCI)
|
||||
driverName = "godror"
|
||||
finalConnStr = fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
config.User, config.Password, connectStringBase)
|
||||
logger.DebugContext(ctx, fmt.Sprintf("Using godror driver (OCI-based) with connectString: %s\n", connectStringBase))
|
||||
} else {
|
||||
// Use go-ora driver (pure Go)
|
||||
driverName = "oracle"
|
||||
|
||||
user := config.User
|
||||
password := config.Password
|
||||
|
||||
if hasWallet {
|
||||
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s?ssl=true&wallet=%s",
|
||||
user, password, connectStringBase, config.WalletLocation)
|
||||
} else {
|
||||
// Standard go-ora connection
|
||||
finalConnStr = fmt.Sprintf("oracle://%s:%s@%s",
|
||||
config.User, config.Password, connectStringBase)
|
||||
logger.DebugContext(ctx, fmt.Sprintf("Using go-ora driver (pure-Go) with serverString: %s\n", connectStringBase))
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open(driverName, finalConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
return nil, fmt.Errorf("unable to open Oracle connection with driver %s: %w", driverName, err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
|
||||
200
internal/sources/oracle/oracle_test.go
Normal file
200
internal/sources/oracle/oracle_test.go
Normal file
@@ -0,0 +1,200 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracle_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources/oracle"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
)
|
||||
|
||||
func TestParseFromYamlOracle(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.SourceConfigs
|
||||
}{
|
||||
{
|
||||
desc: "connection string and useOCI=true",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-cs:
|
||||
kind: oracle
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-cs": oracle.Config{
|
||||
Name: "my-oracle-cs",
|
||||
Kind: oracle.SourceKind,
|
||||
ConnectionString: "my-host:1521/XEPDB1",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
UseOCI: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "host/port/serviceName and default useOCI=false",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-host:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
port: 1521
|
||||
serviceName: ORCLPDB
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-host": oracle.Config{
|
||||
Name: "my-oracle-host",
|
||||
Kind: oracle.SourceKind,
|
||||
Host: "my-host",
|
||||
Port: 1521,
|
||||
ServiceName: "ORCLPDB",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
UseOCI: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "tnsAlias and TnsAdmin specified with explicit useOCI=true",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-tns-oci:
|
||||
kind: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: true
|
||||
`,
|
||||
want: server.SourceConfigs{
|
||||
"my-oracle-tns-oci": oracle.Config{
|
||||
Name: "my-oracle-tns-oci",
|
||||
Kind: oracle.SourceKind,
|
||||
TnsAlias: "FINANCE_DB",
|
||||
TnsAdmin: "/opt/oracle/network/admin",
|
||||
User: "my_user",
|
||||
Password: "my_pass",
|
||||
UseOCI: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Sources) {
|
||||
t.Fatalf("incorrect parse:\nwant: %v\ngot: %v\ndiff: %s", tc.want, got.Sources, cmp.Diff(tc.want, got.Sources))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFailParseFromYamlOracle(t *testing.T) {
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "extra field",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
password: my_pass
|
||||
extraField: value
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": [1:1] unknown field \"extraField\"\n> 1 | extraField: value\n ^\n 2 | host: my-host\n 3 | kind: oracle\n 4 | password: my_pass\n 5 | ",
|
||||
},
|
||||
{
|
||||
desc: "missing required password field",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
user: my_user
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": Key: 'Config.Password' Error:Field validation for 'Password' failed on the 'required' tag",
|
||||
},
|
||||
{
|
||||
desc: "missing connection method fields (validate fails)",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: must provide one of: 'tns_alias', 'connection_string', or both 'host' and 'service_name'",
|
||||
},
|
||||
{
|
||||
desc: "multiple connection methods provided (validate fails)",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-instance:
|
||||
kind: oracle
|
||||
host: my-host
|
||||
serviceName: ORCL
|
||||
connectionString: "my-host:1521/XEPDB1"
|
||||
user: my_user
|
||||
password: my_pass
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-instance\" as \"oracle\": invalid Oracle configuration: provide only one connection method: 'tns_alias', 'connection_string', or 'host'+'service_name'",
|
||||
},
|
||||
{
|
||||
desc: "fail on tnsAdmin with useOCI=false",
|
||||
in: `
|
||||
sources:
|
||||
my-oracle-fail:
|
||||
kind: oracle
|
||||
tnsAlias: FINANCE_DB
|
||||
tnsAdmin: /opt/oracle/network/admin
|
||||
user: my_user
|
||||
password: my_pass
|
||||
useOCI: false
|
||||
`,
|
||||
err: "unable to parse source \"my-oracle-fail\" as \"oracle\": invalid Oracle configuration: `tnsAdmin` can only be used when `UseOCI` is true, or use `walletLocation` instead",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Sources server.SourceConfigs `yaml:"sources"`
|
||||
}{}
|
||||
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err == nil {
|
||||
t.Fatalf("expect parsing to fail")
|
||||
}
|
||||
errStr := strings.ReplaceAll(err.Error(), "\r", "")
|
||||
|
||||
if errStr != tc.err {
|
||||
t.Fatalf("unexpected error:\ngot:\n%q\nwant:\n%q\n", errStr, tc.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -46,6 +46,11 @@ func ContextWithNewLogger() (context.Context, error) {
|
||||
return util.WithLogger(ctx, logger), nil
|
||||
}
|
||||
|
||||
// ContextWithUserAgent creates a new context with a specified user agent string.
|
||||
func ContextWithUserAgent(ctx context.Context, userAgent string) context.Context {
|
||||
return util.WithUserAgent(ctx, userAgent)
|
||||
}
|
||||
|
||||
// WaitForString waits until the server logs a single line that matches the provided regex.
|
||||
// returns the output of whatever the server sent so far.
|
||||
func WaitForString(ctx context.Context, re *regexp.Regexp, pr io.ReadCloser) (string, error) {
|
||||
|
||||
205
internal/tools/cloudgda/cloudgda.go
Normal file
205
internal/tools/cloudgda/cloudgda.go
Normal file
@@ -0,0 +1,205 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/goccy/go-yaml"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
const kind string = "cloud-gemini-data-analytics-query"
|
||||
|
||||
func init() {
|
||||
if !tools.Register(kind, newConfig) {
|
||||
panic(fmt.Sprintf("tool kind %q already registered", kind))
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig(ctx context.Context, name string, decoder *yaml.Decoder) (tools.ToolConfig, error) {
|
||||
actual := Config{Name: name}
|
||||
if err := decoder.DecodeContext(ctx, &actual); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Name string `yaml:"name" validate:"required"`
|
||||
Kind string `yaml:"kind" validate:"required"`
|
||||
Source string `yaml:"source" validate:"required"`
|
||||
Description string `yaml:"description" validate:"required"`
|
||||
Location string `yaml:"location" validate:"required"`
|
||||
Context *QueryDataContext `yaml:"context" validate:"required"`
|
||||
GenerationOptions *GenerationOptions `yaml:"generationOptions,omitempty"`
|
||||
AuthRequired []string `yaml:"authRequired"`
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.ToolConfig = Config{}
|
||||
|
||||
func (cfg Config) ToolConfigKind() string {
|
||||
return kind
|
||||
}
|
||||
|
||||
func (cfg Config) Initialize(srcs map[string]sources.Source) (tools.Tool, error) {
|
||||
// verify source exists
|
||||
rawS, ok := srcs[cfg.Source]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no source named %q configured", cfg.Source)
|
||||
}
|
||||
|
||||
// verify the source is compatible
|
||||
s, ok := rawS.(*cloudgdasrc.Source)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid source for %q tool: source kind must be `cloud-gemini-data-analytics`", kind)
|
||||
}
|
||||
|
||||
// Define the parameters for the Gemini Data Analytics Query API
|
||||
// The prompt is the only input parameter.
|
||||
allParameters := parameters.Parameters{
|
||||
parameters.NewStringParameterWithRequired("prompt", "The natural language question to ask.", true),
|
||||
}
|
||||
|
||||
mcpManifest := tools.GetMcpManifest(cfg.Name, cfg.Description, cfg.AuthRequired, allParameters, nil)
|
||||
|
||||
return Tool{
|
||||
Config: cfg,
|
||||
AllParams: allParameters,
|
||||
Source: s,
|
||||
manifest: tools.Manifest{Description: cfg.Description, Parameters: allParameters.Manifest(), AuthRequired: cfg.AuthRequired},
|
||||
mcpManifest: mcpManifest,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// validate interface
|
||||
var _ tools.Tool = Tool{}
|
||||
|
||||
type Tool struct {
|
||||
Config
|
||||
AllParams parameters.Parameters
|
||||
Source *cloudgdasrc.Source
|
||||
manifest tools.Manifest
|
||||
mcpManifest tools.McpManifest
|
||||
}
|
||||
|
||||
func (t Tool) ToConfig() tools.ToolConfig {
|
||||
return t.Config
|
||||
}
|
||||
|
||||
// Invoke executes the tool logic
|
||||
func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, params parameters.ParamValues, accessToken tools.AccessToken) (any, error) {
|
||||
paramsMap := params.AsMap()
|
||||
prompt, ok := paramsMap["prompt"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("prompt parameter not found or not a string")
|
||||
}
|
||||
|
||||
// The API endpoint itself always uses the "global" location.
|
||||
apiLocation := "global"
|
||||
apiParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, apiLocation)
|
||||
apiURL := fmt.Sprintf("%s/v1beta/%s:queryData", t.Source.BaseURL, apiParent)
|
||||
|
||||
// The parent in the request payload uses the tool's configured location.
|
||||
payloadParent := fmt.Sprintf("projects/%s/locations/%s", t.Source.ProjectID, t.Location)
|
||||
|
||||
payload := &QueryDataRequest{
|
||||
Parent: payloadParent,
|
||||
Prompt: prompt,
|
||||
Context: t.Context,
|
||||
GenerationOptions: t.GenerationOptions,
|
||||
}
|
||||
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request payload: %w", err)
|
||||
}
|
||||
|
||||
// Parse the access token if provided
|
||||
var tokenStr string
|
||||
if t.RequiresClientAuthorization(resourceMgr) {
|
||||
var err error
|
||||
tokenStr, err = accessToken.ParseBearerToken()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing access token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := t.Source.GetClient(ctx, tokenStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get HTTP client: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewBuffer(bodyBytes))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t Tool) ParseParams(data map[string]any, claims map[string]map[string]any) (parameters.ParamValues, error) {
|
||||
return parameters.ParseParams(t.AllParams, data, claims)
|
||||
}
|
||||
|
||||
func (t Tool) Manifest() tools.Manifest {
|
||||
return t.manifest
|
||||
}
|
||||
|
||||
func (t Tool) McpManifest() tools.McpManifest {
|
||||
return t.mcpManifest
|
||||
}
|
||||
|
||||
func (t Tool) Authorized(verifiedAuthServices []string) bool {
|
||||
return tools.IsAuthorized(t.AuthRequired, verifiedAuthServices)
|
||||
}
|
||||
|
||||
func (t Tool) RequiresClientAuthorization(resourceMgr tools.SourceProvider) bool {
|
||||
return t.Source.UseClientAuthorization()
|
||||
}
|
||||
|
||||
func (t Tool) GetAuthTokenHeaderName() string {
|
||||
return "Authorization"
|
||||
}
|
||||
379
internal/tools/cloudgda/cloudgda_test.go
Normal file
379
internal/tools/cloudgda/cloudgda_test.go
Normal file
@@ -0,0 +1,379 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/sources"
|
||||
cloudgdasrc "github.com/googleapis/genai-toolbox/internal/sources/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools"
|
||||
cloudgdatool "github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/internal/util/parameters"
|
||||
)
|
||||
|
||||
func TestParseFromYaml(t *testing.T) {
|
||||
t.Parallel()
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example",
|
||||
in: `
|
||||
tools:
|
||||
my-gda-query-tool:
|
||||
kind: cloud-gemini-data-analytics-query
|
||||
source: gda-api-source
|
||||
description: Test Description
|
||||
location: us-central1
|
||||
context:
|
||||
datasourceReferences:
|
||||
spannerReference:
|
||||
databaseReference:
|
||||
projectId: "cloud-db-nl2sql"
|
||||
region: "us-central1"
|
||||
instanceId: "evalbench"
|
||||
databaseId: "financial"
|
||||
engine: "GOOGLE_SQL"
|
||||
agentContextReference:
|
||||
contextSetId: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates"
|
||||
generationOptions:
|
||||
generateQueryResult: true
|
||||
`,
|
||||
want: map[string]tools.ToolConfig{
|
||||
"my-gda-query-tool": cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "gda-api-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
AuthRequired: []string{},
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
// Parse contents
|
||||
err := yaml.Unmarshal(testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if !cmp.Equal(tc.want, got.Tools) {
|
||||
t.Fatalf("incorrect parse: want %v, got %v", tc.want, got.Tools)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// authRoundTripper is a mock http.RoundTripper that adds a dummy Authorization header.
|
||||
type authRoundTripper struct {
|
||||
Token string
|
||||
Next http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
newReq := *req
|
||||
newReq.Header = make(http.Header)
|
||||
for k, v := range req.Header {
|
||||
newReq.Header[k] = v
|
||||
}
|
||||
newReq.Header.Set("Authorization", rt.Token)
|
||||
if rt.Next == nil {
|
||||
return http.DefaultTransport.RoundTrip(&newReq)
|
||||
}
|
||||
return rt.Next.RoundTrip(&newReq)
|
||||
}
|
||||
|
||||
type mockSource struct {
|
||||
kind string
|
||||
client *http.Client // Can be used to inject a specific client
|
||||
baseURL string // BaseURL is needed to implement sources.Source.BaseURL
|
||||
config cloudgdasrc.Config // to return from ToConfig
|
||||
}
|
||||
|
||||
func (m *mockSource) SourceKind() string { return m.kind }
|
||||
func (m *mockSource) ToConfig() sources.SourceConfig { return m.config }
|
||||
func (m *mockSource) GetClient(ctx context.Context, token string) (*http.Client, error) {
|
||||
if m.client != nil {
|
||||
return m.client, nil
|
||||
}
|
||||
// Default client for testing if not explicitly set
|
||||
transport := &http.Transport{}
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token", // Dummy token
|
||||
Next: transport,
|
||||
}
|
||||
return &http.Client{Transport: authTransport}, nil
|
||||
}
|
||||
func (m *mockSource) UseClientAuthorization() bool { return false }
|
||||
func (m *mockSource) Initialize(ctx context.Context, tracer interface{}) (sources.Source, error) {
|
||||
return m, nil
|
||||
}
|
||||
func (m *mockSource) BaseURL() string { return m.baseURL }
|
||||
|
||||
func TestInitialize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
srcs := map[string]sources.Source{
|
||||
"gda-api-source": &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "gda-api-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: &http.Client{},
|
||||
BaseURL: cloudgdasrc.Endpoint,
|
||||
},
|
||||
}
|
||||
|
||||
tcs := []struct {
|
||||
desc string
|
||||
cfg cloudgdatool.Config
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
desc: "successful initialization",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "gda-api-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
desc: "missing source",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "non-existent-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
desc: "incompatible source kind",
|
||||
cfg: cloudgdatool.Config{
|
||||
Name: "my-gda-query-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "incompatible-source",
|
||||
Description: "Test Description",
|
||||
Location: "us-central1",
|
||||
},
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
// Add an incompatible source for testing
|
||||
srcs["incompatible-source"] = &mockSource{kind: "another-kind"}
|
||||
|
||||
for _, tc := range tcs {
|
||||
tc := tc
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tool, err := tc.cfg.Initialize(srcs)
|
||||
if tc.expectErr && err == nil {
|
||||
t.Fatalf("expected an error but got none")
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
t.Fatalf("did not expect an error but got: %v", err)
|
||||
}
|
||||
if !tc.expectErr {
|
||||
// Basic sanity check on the returned tool
|
||||
_ = tool // Avoid unused variable error
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvoke(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Mock the HTTP client and server for Invoke testing
|
||||
serverMux := http.NewServeMux()
|
||||
// Update expected URL path to include the location "us-central1"
|
||||
serverMux.HandleFunc("/v1beta/projects/test-project/locations/global:queryData", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Errorf("expected POST method, got %s", r.Method)
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Content-Type") != "application/json" {
|
||||
t.Errorf("expected Content-Type application/json, got %s", r.Header.Get("Content-Type"))
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Read and unmarshal the request body
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
t.Errorf("failed to read request body: %v", err)
|
||||
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var reqPayload cloudgdatool.QueryDataRequest
|
||||
if err := json.Unmarshal(bodyBytes, &reqPayload); err != nil {
|
||||
t.Errorf("failed to unmarshal request payload: %v", err)
|
||||
http.Error(w, "Bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify expected fields
|
||||
if r.Header.Get("Authorization") == "" {
|
||||
t.Errorf("expected Authorization header, got empty")
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if reqPayload.Prompt != "How many accounts who have region in Prague are eligible for loans?" {
|
||||
t.Errorf("unexpected prompt: %s", reqPayload.Prompt)
|
||||
}
|
||||
|
||||
// Verify payload's parent uses the tool's configured location
|
||||
if reqPayload.Parent != fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1") {
|
||||
t.Errorf("unexpected payload parent: got %q, want %q", reqPayload.Parent, fmt.Sprintf("projects/%s/locations/%s", "test-project", "us-central1"))
|
||||
}
|
||||
|
||||
// Verify context from config
|
||||
if reqPayload.Context == nil ||
|
||||
reqPayload.Context.DatasourceReferences == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference == nil ||
|
||||
reqPayload.Context.DatasourceReferences.SpannerReference.DatabaseReference.ProjectID != "cloud-db-nl2sql" {
|
||||
t.Errorf("unexpected context: %v", reqPayload.Context)
|
||||
}
|
||||
|
||||
// Verify generation options from config
|
||||
if reqPayload.GenerationOptions == nil || !reqPayload.GenerationOptions.GenerateQueryResult {
|
||||
t.Errorf("unexpected generation options: %v", reqPayload.GenerationOptions)
|
||||
}
|
||||
|
||||
// Simulate a successful response
|
||||
resp := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
})
|
||||
|
||||
mockServer := httptest.NewServer(serverMux)
|
||||
defer mockServer.Close()
|
||||
|
||||
ctx := testutils.ContextWithUserAgent(context.Background(), "test-user-agent")
|
||||
|
||||
// Create an authenticated client that uses the mock server
|
||||
authTransport := &authRoundTripper{
|
||||
Token: "Bearer test-access-token",
|
||||
Next: mockServer.Client().Transport,
|
||||
}
|
||||
authClient := &http.Client{Transport: authTransport}
|
||||
|
||||
// Create a real cloudgdasrc.Source but inject the authenticated client
|
||||
mockGdaSource := &cloudgdasrc.Source{
|
||||
Config: cloudgdasrc.Config{Name: "mock-gda-source", Kind: cloudgdasrc.SourceKind, ProjectID: "test-project"},
|
||||
Client: authClient,
|
||||
BaseURL: mockServer.URL,
|
||||
}
|
||||
srcs := map[string]sources.Source{
|
||||
"mock-gda-source": mockGdaSource,
|
||||
}
|
||||
|
||||
// Initialize the tool config with context
|
||||
toolCfg := cloudgdatool.Config{
|
||||
Name: "query-data-tool",
|
||||
Kind: "cloud-gemini-data-analytics-query",
|
||||
Source: "mock-gda-source",
|
||||
Description: "Query Gemini Data Analytics",
|
||||
Location: "us-central1", // Set location for the test
|
||||
Context: &cloudgdatool.QueryDataContext{
|
||||
DatasourceReferences: &cloudgdatool.DatasourceReferences{
|
||||
SpannerReference: &cloudgdatool.SpannerReference{
|
||||
DatabaseReference: &cloudgdatool.SpannerDatabaseReference{
|
||||
ProjectID: "cloud-db-nl2sql",
|
||||
Region: "us-central1",
|
||||
InstanceID: "evalbench",
|
||||
DatabaseID: "financial",
|
||||
Engine: cloudgdatool.SpannerEngineGoogleSQL,
|
||||
},
|
||||
AgentContextReference: &cloudgdatool.AgentContextReference{
|
||||
ContextSetID: "projects/cloud-db-nl2sql/locations/us-east1/contextSets/bdf_gsql_gemini_all_templates",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GenerationOptions: &cloudgdatool.GenerationOptions{
|
||||
GenerateQueryResult: true,
|
||||
},
|
||||
}
|
||||
|
||||
tool, err := toolCfg.Initialize(srcs)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize tool: %v", err)
|
||||
}
|
||||
|
||||
// Prepare parameters for invocation - ONLY prompt
|
||||
params := parameters.ParamValues{
|
||||
{Name: "prompt", Value: "How many accounts who have region in Prague are eligible for loans?"},
|
||||
}
|
||||
|
||||
// Invoke the tool
|
||||
result, err := tool.Invoke(ctx, nil, params, "") // No accessToken needed for ADC client
|
||||
if err != nil {
|
||||
t.Fatalf("tool invocation failed: %v", err)
|
||||
}
|
||||
|
||||
// Validate the result
|
||||
expectedResult := map[string]any{
|
||||
"queryResult": "SELECT count(*) FROM accounts WHERE region = 'Prague' AND eligible_for_loans = true;",
|
||||
"naturalLanguageAnswer": "There are 5 accounts in Prague eligible for loans.",
|
||||
}
|
||||
|
||||
if !cmp.Equal(expectedResult, result) {
|
||||
t.Errorf("unexpected result: got %v, want %v", result, expectedResult)
|
||||
}
|
||||
}
|
||||
116
internal/tools/cloudgda/types.go
Normal file
116
internal/tools/cloudgda/types.go
Normal file
@@ -0,0 +1,116 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda
|
||||
|
||||
// See full service definition at: https://github.com/googleapis/googleapis/blob/master/google/cloud/geminidataanalytics/v1beta/data_chat_service.proto
|
||||
|
||||
// QueryDataRequest represents the JSON body for the queryData API
|
||||
type QueryDataRequest struct {
|
||||
Parent string `json:"parent"`
|
||||
Prompt string `json:"prompt"`
|
||||
Context *QueryDataContext `json:"context,omitempty"`
|
||||
GenerationOptions *GenerationOptions `json:"generationOptions,omitempty"`
|
||||
}
|
||||
|
||||
// QueryDataContext reflects the proto definition for the query context.
|
||||
type QueryDataContext struct {
|
||||
DatasourceReferences *DatasourceReferences `json:"datasourceReferences,omitempty" yaml:"datasourceReferences,omitempty"`
|
||||
}
|
||||
|
||||
// DatasourceReferences reflects the proto definition for datasource references, using a oneof.
|
||||
type DatasourceReferences struct {
|
||||
SpannerReference *SpannerReference `json:"spannerReference,omitempty" yaml:"spannerReference,omitempty"`
|
||||
AlloyDBReference *AlloyDBReference `json:"alloydb,omitempty" yaml:"alloydb,omitempty"`
|
||||
CloudSQLReference *CloudSQLReference `json:"cloudSqlReference,omitempty" yaml:"cloudSqlReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerReference reflects the proto definition for Spanner database reference.
|
||||
type SpannerReference struct {
|
||||
DatabaseReference *SpannerDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerDatabaseReference reflects the proto definition for a Spanner database reference.
|
||||
type SpannerDatabaseReference struct {
|
||||
Engine SpannerEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// SpannerEngine represents the engine of the Spanner instance.
|
||||
type SpannerEngine string
|
||||
|
||||
const (
|
||||
SpannerEngineUnspecified SpannerEngine = "ENGINE_UNSPECIFIED"
|
||||
SpannerEngineGoogleSQL SpannerEngine = "GOOGLE_SQL"
|
||||
SpannerEnginePostgreSQL SpannerEngine = "POSTGRESQL"
|
||||
)
|
||||
|
||||
// AlloyDBReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBReference struct {
|
||||
DatabaseReference *AlloyDBDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// AlloyDBDatabaseReference reflects the proto definition for an AlloyDB database reference.
|
||||
type AlloyDBDatabaseReference struct {
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
ClusterID string `json:"clusterId,omitempty" yaml:"clusterId,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLReference struct {
|
||||
DatabaseReference *CloudSQLDatabaseReference `json:"databaseReference,omitempty" yaml:"databaseReference,omitempty"`
|
||||
AgentContextReference *AgentContextReference `json:"agentContextReference,omitempty" yaml:"agentContextReference,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLDatabaseReference reflects the proto definition for a Cloud SQL database reference.
|
||||
type CloudSQLDatabaseReference struct {
|
||||
Engine CloudSQLEngine `json:"engine,omitempty" yaml:"engine,omitempty"`
|
||||
ProjectID string `json:"projectId,omitempty" yaml:"projectId,omitempty"`
|
||||
Region string `json:"region,omitempty" yaml:"region,omitempty"`
|
||||
InstanceID string `json:"instanceId,omitempty" yaml:"instanceId,omitempty"`
|
||||
DatabaseID string `json:"databaseId,omitempty" yaml:"databaseId,omitempty"`
|
||||
TableIDs []string `json:"tableIds,omitempty" yaml:"tableIds,omitempty"`
|
||||
}
|
||||
|
||||
// CloudSQLEngine represents the engine of the Cloud SQL instance.
|
||||
type CloudSQLEngine string
|
||||
|
||||
const (
|
||||
CloudSQLEngineUnspecified CloudSQLEngine = "ENGINE_UNSPECIFIED"
|
||||
CloudSQLEnginePostgreSQL CloudSQLEngine = "POSTGRESQL"
|
||||
CloudSQLEngineMySQL CloudSQLEngine = "MYSQL"
|
||||
)
|
||||
|
||||
// AgentContextReference reflects the proto definition for agent context.
|
||||
type AgentContextReference struct {
|
||||
ContextSetID string `json:"contextSetId,omitempty" yaml:"contextSetId,omitempty"`
|
||||
}
|
||||
|
||||
// GenerationOptions reflects the proto definition for generation options.
|
||||
type GenerationOptions struct {
|
||||
GenerateQueryResult bool `json:"generateQueryResult" yaml:"generateQueryResult"`
|
||||
GenerateNaturalLanguageAnswer bool `json:"generateNaturalLanguageAnswer" yaml:"generateNaturalLanguageAnswer"`
|
||||
GenerateExplanation bool `json:"generateExplanation" yaml:"generateExplanation"`
|
||||
GenerateDisambiguationQuestion bool `json:"generateDisambiguationQuestion" yaml:"generateDisambiguationQuestion"`
|
||||
}
|
||||
@@ -158,7 +158,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
dashboard_id := paramsMap["dashboard_id"].(string)
|
||||
name := paramsMap["name"].(string)
|
||||
title := paramsMap["title"].(string)
|
||||
filterType := paramsMap["flter_type"].(string)
|
||||
filterType := paramsMap["filter_type"].(string)
|
||||
switch filterType {
|
||||
case "date_filter":
|
||||
case "number_filter":
|
||||
|
||||
@@ -110,7 +110,7 @@ func (t Tool) Invoke(ctx context.Context, resourceMgr tools.SourceProvider, para
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting logger: %s", err)
|
||||
}
|
||||
logger.DebugContext(ctx, fmt.Sprintf("executing `%s` tool query: %s", kind, sqlParam))
|
||||
logger.DebugContext(ctx, "executing `%s` tool query: %s", kind, sqlParam)
|
||||
|
||||
results, err := t.Pool.QueryContext(ctx, sqlParam)
|
||||
if err != nil {
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
|
||||
package oracleexecutesql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/oracle/oracleexecutesql"
|
||||
)
|
||||
|
||||
func TestParseFromYamlOracleExecuteSql(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example with auth",
|
||||
in: `
|
||||
tools:
|
||||
run_adhoc_query:
|
||||
kind: oracle-execute-sql
|
||||
source: my-oracle-instance
|
||||
description: Executes arbitrary SQL statements like INSERT or UPDATE.
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"run_adhoc_query": oracleexecutesql.Config{
|
||||
Name: "run_adhoc_query",
|
||||
Kind: "oracle-execute-sql",
|
||||
Source: "my-oracle-instance",
|
||||
Description: "Executes arbitrary SQL statements like INSERT or UPDATE.",
|
||||
AuthRequired: []string{"my-google-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "example without authRequired",
|
||||
in: `
|
||||
tools:
|
||||
run_simple_update:
|
||||
kind: oracle-execute-sql
|
||||
source: db-dev
|
||||
description: Runs a simple update operation.
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"run_simple_update": oracleexecutesql.Config{
|
||||
Name: "run_simple_update",
|
||||
Kind: "oracle-execute-sql",
|
||||
Source: "db-dev",
|
||||
Description: "Runs a simple update operation.",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
85
internal/tools/oracle/oraclesql/oraclesql_test.go
Normal file
85
internal/tools/oracle/oraclesql/oraclesql_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright © 2025, Oracle and/or its affiliates.
|
||||
package oraclesql_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
yaml "github.com/goccy/go-yaml"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/googleapis/genai-toolbox/internal/server"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/oracle/oraclesql"
|
||||
)
|
||||
|
||||
func TestParseFromYamlOracleSql(t *testing.T) {
|
||||
ctx, err := testutils.ContextWithNewLogger()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
tcs := []struct {
|
||||
desc string
|
||||
in string
|
||||
want server.ToolConfigs
|
||||
}{
|
||||
{
|
||||
desc: "basic example with statement and auth",
|
||||
in: `
|
||||
tools:
|
||||
get_user_by_id:
|
||||
kind: oracle-sql
|
||||
source: my-oracle-instance
|
||||
description: Retrieves user details by ID.
|
||||
statement: "SELECT id, name, email FROM users WHERE id = :1"
|
||||
authRequired:
|
||||
- my-google-auth-service
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get_user_by_id": oraclesql.Config{
|
||||
Name: "get_user_by_id",
|
||||
Kind: "oracle-sql",
|
||||
Source: "my-oracle-instance",
|
||||
Description: "Retrieves user details by ID.",
|
||||
Statement: "SELECT id, name, email FROM users WHERE id = :1",
|
||||
AuthRequired: []string{"my-google-auth-service"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "example with parameters and template parameters",
|
||||
in: `
|
||||
tools:
|
||||
get_orders:
|
||||
kind: oracle-sql
|
||||
source: db-prod
|
||||
description: Gets orders for a customer with optional filtering.
|
||||
statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status"
|
||||
`,
|
||||
want: server.ToolConfigs{
|
||||
"get_orders": oraclesql.Config{
|
||||
Name: "get_orders",
|
||||
Kind: "oracle-sql",
|
||||
Source: "db-prod",
|
||||
Description: "Gets orders for a customer with optional filtering.",
|
||||
Statement: "SELECT * FROM ${SCHEMA}.ORDERS WHERE customer_id = :customer_id AND status = :status",
|
||||
AuthRequired: []string{},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
got := struct {
|
||||
Tools server.ToolConfigs `yaml:"tools"`
|
||||
}{}
|
||||
|
||||
err := yaml.UnmarshalContext(ctx, testutils.FormatYaml(tc.in), &got)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to unmarshal: %s", err)
|
||||
}
|
||||
if diff := cmp.Diff(tc.want, got.Tools); diff != "" {
|
||||
t.Fatalf("incorrect parse: diff %v", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
233
tests/cloudgda/cloud_gda_integration_test.go
Normal file
233
tests/cloudgda/cloud_gda_integration_test.go
Normal file
@@ -0,0 +1,233 @@
|
||||
// Copyright 2025 Google LLC
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cloudgda_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/googleapis/genai-toolbox/internal/server/mcp/jsonrpc"
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/tools/cloudgda"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudGdaToolKind = "cloud-gemini-data-analytics-query"
|
||||
)
|
||||
|
||||
type cloudGdaTransport struct {
|
||||
transport http.RoundTripper
|
||||
url *url.URL
|
||||
}
|
||||
|
||||
func (t *cloudGdaTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if strings.HasPrefix(req.URL.String(), "https://geminidataanalytics.googleapis.com") {
|
||||
req.URL.Scheme = t.url.Scheme
|
||||
req.URL.Host = t.url.Host
|
||||
}
|
||||
return t.transport.RoundTrip(req)
|
||||
}
|
||||
|
||||
type masterHandler struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (h *masterHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.Contains(r.UserAgent(), "genai-toolbox/") {
|
||||
h.t.Errorf("User-Agent header not found")
|
||||
}
|
||||
|
||||
if r.Method != http.MethodPost {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
|
||||
// Verify URL structure
|
||||
// Expected: /v1beta/projects/{project}/locations/global:queryData
|
||||
if !strings.Contains(r.URL.Path, ":queryData") || !strings.Contains(r.URL.Path, "locations/global") {
|
||||
h.t.Errorf("unexpected URL path: %s", r.URL.Path)
|
||||
http.Error(w, "Not found", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody cloudgda.QueryDataRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
h.t.Fatalf("failed to decode request body: %v", err)
|
||||
}
|
||||
|
||||
if reqBody.Prompt == "" {
|
||||
http.Error(w, "missing prompt", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
response := map[string]any{
|
||||
"queryResult": "SELECT * FROM table;",
|
||||
"naturalLanguageAnswer": "Here is the answer.",
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloudGdaToolEndpoints(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
handler := &masterHandler{t: t}
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
||||
serverURL, err := url.Parse(server.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse server URL: %v", err)
|
||||
}
|
||||
|
||||
originalTransport := http.DefaultClient.Transport
|
||||
if originalTransport == nil {
|
||||
originalTransport = http.DefaultTransport
|
||||
}
|
||||
http.DefaultClient.Transport = &cloudGdaTransport{
|
||||
transport: originalTransport,
|
||||
url: serverURL,
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
http.DefaultClient.Transport = originalTransport
|
||||
})
|
||||
|
||||
var args []string
|
||||
toolsFile := getCloudGdaToolsConfig()
|
||||
cmd, cleanup, err := tests.StartCmd(ctx, toolsFile, args...)
|
||||
if err != nil {
|
||||
t.Fatalf("command initialization returned an error: %s", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
waitCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
out, err := testutils.WaitForString(waitCtx, regexp.MustCompile(`Server ready to serve`), cmd.Out)
|
||||
if err != nil {
|
||||
t.Logf("toolbox command logs: \n%s", out)
|
||||
t.Fatalf("toolbox didn't start successfully: %s", err)
|
||||
}
|
||||
|
||||
toolName := "cloud-gda-query"
|
||||
|
||||
// 1. RunToolGetTestByName
|
||||
expectedManifest := map[string]any{
|
||||
toolName: map[string]any{
|
||||
"description": "Test GDA Tool",
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"name": "prompt",
|
||||
"type": "string",
|
||||
"description": "The natural language question to ask.",
|
||||
"required": true,
|
||||
"authSources": []any{},
|
||||
},
|
||||
},
|
||||
"authRequired": []any{},
|
||||
},
|
||||
}
|
||||
tests.RunToolGetTestByName(t, toolName, expectedManifest)
|
||||
|
||||
// 2. RunToolInvokeParametersTest
|
||||
params := []byte(`{"prompt": "test question"}`)
|
||||
tests.RunToolInvokeParametersTest(t, toolName, params, "\"queryResult\":\"SELECT * FROM table;\"")
|
||||
|
||||
// 3. Manual MCP Tool Call Test
|
||||
// Initialize MCP session
|
||||
sessionId := tests.RunInitialize(t, "2024-11-05")
|
||||
|
||||
// Construct MCP Request
|
||||
mcpReq := jsonrpc.JSONRPCRequest{
|
||||
Jsonrpc: "2.0",
|
||||
Id: "test-mcp-call",
|
||||
Request: jsonrpc.Request{
|
||||
Method: "tools/call",
|
||||
},
|
||||
Params: map[string]any{
|
||||
"name": toolName,
|
||||
"arguments": map[string]any{
|
||||
"prompt": "test question",
|
||||
},
|
||||
},
|
||||
}
|
||||
reqBytes, _ := json.Marshal(mcpReq)
|
||||
|
||||
headers := map[string]string{}
|
||||
if sessionId != "" {
|
||||
headers["Mcp-Session-Id"] = sessionId
|
||||
}
|
||||
|
||||
// Send Request
|
||||
resp, respBody := tests.RunRequest(t, http.MethodPost, "http://127.0.0.1:5000/mcp", bytes.NewBuffer(reqBytes), headers)
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("MCP request failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
// Check Response
|
||||
respStr := string(respBody)
|
||||
if !strings.Contains(respStr, "SELECT * FROM table;") {
|
||||
t.Errorf("MCP response does not contain expected query result: %s", respStr)
|
||||
}
|
||||
}
|
||||
|
||||
func getCloudGdaToolsConfig() map[string]any {
|
||||
// Mocked responses and a dummy `projectId` are used in this integration
|
||||
// test due to limited project-specific allowlisting. API functionality is
|
||||
// verified via internal monitoring; this test specifically validates the
|
||||
// integration flow between the source and the tool.
|
||||
return map[string]any{
|
||||
"sources": map[string]any{
|
||||
"my-gda-source": map[string]any{
|
||||
"kind": "cloud-gemini-data-analytics",
|
||||
"projectId": "test-project",
|
||||
},
|
||||
},
|
||||
"tools": map[string]any{
|
||||
"cloud-gda-query": map[string]any{
|
||||
"kind": cloudGdaToolKind,
|
||||
"source": "my-gda-source",
|
||||
"description": "Test GDA Tool",
|
||||
"location": "us-central1",
|
||||
"context": map[string]any{
|
||||
"datasourceReferences": map[string]any{
|
||||
"spannerReference": map[string]any{
|
||||
"databaseReference": map[string]any{
|
||||
"projectId": "test-project",
|
||||
"instanceId": "test-instance",
|
||||
"databaseId": "test-db",
|
||||
"engine": "GOOGLE_SQL",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -30,6 +30,9 @@ import (
|
||||
"github.com/googleapis/genai-toolbox/internal/testutils"
|
||||
"github.com/googleapis/genai-toolbox/internal/util"
|
||||
"github.com/googleapis/genai-toolbox/tests"
|
||||
|
||||
"github.com/looker-open-source/sdk-codegen/go/rtl"
|
||||
v4 "github.com/looker-open-source/sdk-codegen/go/sdk/v4"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -139,11 +142,31 @@ func TestLooker(t *testing.T) {
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"make_look": map[string]any{
|
||||
"kind": "looker-make-look",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"get_dashboards": map[string]any{
|
||||
"kind": "looker-get-dashboards",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"make_dashboard": map[string]any{
|
||||
"kind": "looker-make-dashboard",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"add_dashboard_filter": map[string]any{
|
||||
"kind": "looker-add-dashboard-filter",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"add_dashboard_element": map[string]any{
|
||||
"kind": "looker-add-dashboard-element",
|
||||
"source": "my-instance",
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
},
|
||||
"conversational_analytics": map[string]any{
|
||||
"kind": "looker-conversational-analytics",
|
||||
"source": "my-instance",
|
||||
@@ -678,6 +701,116 @@ func TestLooker(t *testing.T) {
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "make_look",
|
||||
map[string]any{
|
||||
"make_look": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"authRequired": []any{},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The model containing the explore.",
|
||||
"name": "model",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The explore to be queried.",
|
||||
"name": "explore",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The fields to be retrieved.",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be returned in the query",
|
||||
"name": "field",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "fields",
|
||||
"required": true,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"additionalProperties": true,
|
||||
"authSources": []any{},
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The query pivots (must be included in fields as well).",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be used as a pivot in the query",
|
||||
"name": "pivot_field",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The sorts like \"field.id desc 0\".",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be used as a sort in the query",
|
||||
"name": "sort_field",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The row limit.",
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The query timezone.",
|
||||
"name": "tz",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The title of the Look",
|
||||
"name": "title",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The description of the Look",
|
||||
"name": "description",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"additionalProperties": true,
|
||||
"authSources": []any{},
|
||||
"description": "The visualization config for the query",
|
||||
"name": "vis_config",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "get_dashboards",
|
||||
map[string]any{
|
||||
"get_dashboards": map[string]any{
|
||||
@@ -716,6 +849,235 @@ func TestLooker(t *testing.T) {
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "make_dashboard",
|
||||
map[string]any{
|
||||
"make_dashboard": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"authRequired": []any{},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The title of the Dashboard",
|
||||
"name": "title",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The description of the Dashboard",
|
||||
"name": "description",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "add_dashboard_filter",
|
||||
map[string]any{
|
||||
"add_dashboard_filter": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"authRequired": []any{},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The id of the dashboard where this filter will exist",
|
||||
"name": "dashboard_id",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The name of the Dashboard Filter",
|
||||
"name": "name",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The title of the Dashboard Filter",
|
||||
"name": "title",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The filter_type of the Dashboard Filter: date_filter, number_filter, string_filter, field_filter (default field_filter)",
|
||||
"name": "filter_type",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The default_value of the Dashboard Filter (optional)",
|
||||
"name": "default_value",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The model of a field type Dashboard Filter (required if type field)",
|
||||
"name": "model",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The explore of a field type Dashboard Filter (required if type field)",
|
||||
"name": "explore",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The dimension of a field type Dashboard Filter (required if type field)",
|
||||
"name": "dimension",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The Dashboard Filter should allow multiple values (default true)",
|
||||
"name": "allow_multiple_values",
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The Dashboard Filter is required to run dashboard (default false)",
|
||||
"name": "required",
|
||||
"required": false,
|
||||
"type": "boolean",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "add_dashboard_element",
|
||||
map[string]any{
|
||||
"add_dashboard_element": map[string]any{
|
||||
"description": "Simple tool to test end to end functionality.",
|
||||
"authRequired": []any{},
|
||||
"parameters": []any{
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The model containing the explore.",
|
||||
"name": "model",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The explore to be queried.",
|
||||
"name": "explore",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The fields to be retrieved.",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be returned in the query",
|
||||
"name": "field",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "fields",
|
||||
"required": true,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"additionalProperties": true,
|
||||
"authSources": []any{},
|
||||
"description": "The filters for the query",
|
||||
"name": "filters",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The query pivots (must be included in fields as well).",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be used as a pivot in the query",
|
||||
"name": "pivot_field",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "pivots",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The sorts like \"field.id desc 0\".",
|
||||
"items": map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "A field to be used as a sort in the query",
|
||||
"name": "sort_field",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
"name": "sorts",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The row limit.",
|
||||
"name": "limit",
|
||||
"required": false,
|
||||
"type": "integer",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The query timezone.",
|
||||
"name": "tz",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The id of the dashboard where this tile will exist",
|
||||
"name": "dashboard_id",
|
||||
"required": true,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": "The title of the Dashboard Element",
|
||||
"name": "title",
|
||||
"required": false,
|
||||
"type": "string",
|
||||
},
|
||||
map[string]any{
|
||||
"additionalProperties": true,
|
||||
"authSources": []any{},
|
||||
"description": "The visualization config for the query",
|
||||
"name": "vis_config",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
},
|
||||
map[string]any{
|
||||
"authSources": []any{},
|
||||
"description": `An array of dashboard filters like [{"dashboard_filter_name": "name", "field": "view_name.field_name"}, ...]`,
|
||||
"items": map[string]any{
|
||||
"additionalProperties": true,
|
||||
"authSources": []any{},
|
||||
"description": `A dashboard filter like {"dashboard_filter_name": "name", "field": "view_name.field_name"}`,
|
||||
"name": "dashboard_filter",
|
||||
"required": false,
|
||||
"type": "object",
|
||||
},
|
||||
"name": "dashboard_filters",
|
||||
"required": false,
|
||||
"type": "array",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
tests.RunToolGetTestByName(t, "conversational_analytics",
|
||||
map[string]any{
|
||||
"conversational_analytics": map[string]any{
|
||||
@@ -1200,8 +1562,6 @@ func TestLooker(t *testing.T) {
|
||||
wantResult = "null"
|
||||
tests.RunToolInvokeParametersTest(t, "get_dashboards", []byte(`{"title": "FOO", "desc": "BAR"}`), wantResult)
|
||||
|
||||
runConversationalAnalytics(t, "system__activity", "content_usage")
|
||||
|
||||
wantResult = "\"Connection\":\"thelook\""
|
||||
tests.RunToolInvokeParametersTest(t, "health_pulse", []byte(`{"action": "check_db_connections"}`), wantResult)
|
||||
|
||||
@@ -1261,6 +1621,16 @@ func TestLooker(t *testing.T) {
|
||||
|
||||
wantResult = "/login/embed?t=" // testing for specific substring, since url is dynamic
|
||||
tests.RunToolInvokeParametersTest(t, "generate_embed_url", []byte(`{"type": "dashboards", "id": "1"}`), wantResult)
|
||||
|
||||
runConversationalAnalytics(t, "system__activity", "content_usage")
|
||||
|
||||
deleteLook := testMakeLook(t)
|
||||
defer deleteLook()
|
||||
|
||||
dashboardId, deleteDashboard := testMakeDashboard(t)
|
||||
defer deleteDashboard()
|
||||
testAddDashboardFilter(t, dashboardId)
|
||||
testAddDashboardElement(t, dashboardId)
|
||||
}
|
||||
|
||||
func runConversationalAnalytics(t *testing.T, modelName, exploreName string) {
|
||||
@@ -1325,3 +1695,122 @@ func runConversationalAnalytics(t *testing.T, modelName, exploreName string) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func newLookerTestSDK(t *testing.T) *v4.LookerSDK {
|
||||
t.Helper()
|
||||
cfg := rtl.ApiSettings{
|
||||
BaseUrl: LookerBaseUrl,
|
||||
ApiVersion: "4.0",
|
||||
VerifySsl: LookerVerifySsl == "true",
|
||||
Timeout: 120,
|
||||
ClientId: LookerClientId,
|
||||
ClientSecret: LookerClientSecret,
|
||||
}
|
||||
return v4.NewLookerSDK(rtl.NewAuthSession(cfg))
|
||||
}
|
||||
|
||||
func testMakeLook(t *testing.T) func() {
|
||||
var id string
|
||||
t.Run("TestMakeLook", func(t *testing.T) {
|
||||
reqBody := []byte(`{"model": "system__activity", "explore": "look", "fields": ["look.count"], "title": "TestLook"}`)
|
||||
|
||||
url := "http://127.0.0.1:5000/api/tool/make_look/invoke"
|
||||
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes))
|
||||
}
|
||||
|
||||
var respBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
|
||||
result := respBody["result"].(string)
|
||||
if err := json.Unmarshal([]byte(result), &respBody); err != nil {
|
||||
t.Fatalf("error parsing result body: %v", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if id, ok = respBody["id"].(string); !ok || id == "" {
|
||||
t.Fatalf("didn't get TestLook id, got %s", string(bodyBytes))
|
||||
}
|
||||
})
|
||||
|
||||
return func() {
|
||||
sdk := newLookerTestSDK(t)
|
||||
|
||||
if _, err := sdk.DeleteLook(id, nil); err != nil {
|
||||
t.Fatalf("error deleting look: %v", err)
|
||||
}
|
||||
t.Logf("deleted Look %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
func testAddDashboardFilter(t *testing.T, dashboardId string) {
|
||||
t.Run("TestAddDashboardFilter", func(t *testing.T) {
|
||||
reqBody := []byte(fmt.Sprintf(`{"dashboard_id": "%s", "model": "system__activity", "explore": "look", "dimension": "look.created_year", "name": "test_filter", "title": "TestDashboardFilter"}`, dashboardId))
|
||||
|
||||
url := "http://127.0.0.1:5000/api/tool/add_dashboard_filter/invoke"
|
||||
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes))
|
||||
}
|
||||
|
||||
t.Logf("got %s", string(bodyBytes))
|
||||
})
|
||||
}
|
||||
|
||||
func testAddDashboardElement(t *testing.T, dashboardId string) {
|
||||
t.Run("TestAddDashboardElement", func(t *testing.T) {
|
||||
reqBody := []byte(fmt.Sprintf(`{"dashboard_id": "%s", "model": "system__activity", "explore": "look", "fields": ["look.count"], "title": "TestDashboardElement"}`, dashboardId))
|
||||
|
||||
url := "http://127.0.0.1:5000/api/tool/add_dashboard_element/invoke"
|
||||
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes))
|
||||
}
|
||||
|
||||
t.Logf("got %s", string(bodyBytes))
|
||||
})
|
||||
}
|
||||
|
||||
func testMakeDashboard(t *testing.T) (string, func()) {
|
||||
var id string
|
||||
t.Run("TestMakeDashboard", func(t *testing.T) {
|
||||
reqBody := []byte(`{"title": "TestDashboard"}`)
|
||||
|
||||
url := "http://127.0.0.1:5000/api/tool/make_dashboard/invoke"
|
||||
resp, bodyBytes := tests.RunRequest(t, http.MethodPost, url, bytes.NewBuffer(reqBody), nil)
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
t.Fatalf("unexpected status code: got %d, want %d. Body: %s", resp.StatusCode, 200, string(bodyBytes))
|
||||
}
|
||||
|
||||
var respBody map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &respBody); err != nil {
|
||||
t.Fatalf("error parsing response body: %v", err)
|
||||
}
|
||||
|
||||
result := respBody["result"].(string)
|
||||
if err := json.Unmarshal([]byte(result), &respBody); err != nil {
|
||||
t.Fatalf("error parsing result body: %v", err)
|
||||
}
|
||||
|
||||
var ok bool
|
||||
if id, ok = respBody["id"].(string); !ok || id == "" {
|
||||
t.Fatalf("didn't get TestDashboard id, got %s", string(bodyBytes))
|
||||
}
|
||||
})
|
||||
|
||||
return id, func() {
|
||||
sdk := newLookerTestSDK(t)
|
||||
|
||||
if _, err := sdk.DeleteDashboard(id, nil); err != nil {
|
||||
t.Fatalf("error deleting dashboard: %v", err)
|
||||
}
|
||||
t.Logf("deleted Dashboard %s", id)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ func getOracleVars(t *testing.T) map[string]any {
|
||||
return map[string]any{
|
||||
"kind": OracleSourceKind,
|
||||
"connectionString": OracleConnStr,
|
||||
"useOCI": true,
|
||||
"user": OracleUser,
|
||||
"password": OraclePass,
|
||||
}
|
||||
@@ -50,9 +51,11 @@ func getOracleVars(t *testing.T) map[string]any {
|
||||
|
||||
// Copied over from oracle.go
|
||||
func initOracleConnection(ctx context.Context, user, pass, connStr string) (*sql.DB, error) {
|
||||
fullConnStr := fmt.Sprintf("oracle://%s:%s@%s", user, pass, connStr)
|
||||
// Build the full Oracle connection string for godror driver
|
||||
fullConnStr := fmt.Sprintf(`user="%s" password="%s" connectString="%s"`,
|
||||
user, pass, connStr)
|
||||
|
||||
db, err := sql.Open("oracle", fullConnStr)
|
||||
db, err := sql.Open("godror", fullConnStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to open Oracle connection: %w", err)
|
||||
}
|
||||
@@ -116,13 +119,15 @@ func TestOracleSimpleToolEndpoints(t *testing.T) {
|
||||
|
||||
// Get configs for tests
|
||||
select1Want := "[{\"1\":1}]"
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: ORA-00900: invalid SQL statement\n error occur at position: 0"}],"isError":true}}`
|
||||
mcpMyFailToolWant := `{"jsonrpc":"2.0","id":"invoke-fail-tool","result":{"content":[{"type":"text","text":"unable to execute query: dpiStmt_execute: ORA-00900: invalid SQL statement"}],"isError":true}}`
|
||||
createTableStatement := `"CREATE TABLE t (id NUMBER GENERATED AS IDENTITY PRIMARY KEY, name VARCHAR2(255))"`
|
||||
mcpSelect1Want := `{"jsonrpc":"2.0","id":"invoke my-auth-required-tool","result":{"content":[{"type":"text","text":"{\"1\":1}"}]}}`
|
||||
|
||||
// Run tests
|
||||
tests.RunToolGetTest(t)
|
||||
tests.RunToolInvokeTest(t, select1Want,
|
||||
tests.DisableOptionalNullParamTest(),
|
||||
tests.WithMyToolById4Want("[{\"id\":4,\"name\":\"\"}]"),
|
||||
tests.DisableArrayTest(),
|
||||
)
|
||||
tests.RunMCPToolCallMethod(t, mcpMyFailToolWant, mcpSelect1Want)
|
||||
|
||||
@@ -2401,10 +2401,10 @@ func RunPostgresListPgSettingsTest(t *testing.T, ctx context.Context, pool *pgxp
|
||||
// RunPostgresDatabaseStatsTest tests the database_stats tool by comparing API results
|
||||
// against a direct query to the database.
|
||||
func RunPostgresListDatabaseStatsTest(t *testing.T, ctx context.Context, pool *pgxpool.Pool) {
|
||||
dbName1 := "test_db_stats_1"
|
||||
dbOwner1 := "test_user1"
|
||||
dbName2 := "test_db_stats_2"
|
||||
dbOwner2 := "test_user2"
|
||||
dbName1 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
dbOwner1 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
dbName2 := "test_db_stats_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
dbOwner2 := "test_user_" + strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
|
||||
cleanup1 := setUpDatabase(t, ctx, pool, dbName1, dbOwner1)
|
||||
defer cleanup1()
|
||||
|
||||
Reference in New Issue
Block a user