adk working file

This commit is contained in:
Twisha Bansal
2026-02-09 11:13:09 +05:30
parent 632dd10180
commit 0b23fc950f

View File

@@ -0,0 +1,139 @@
import asyncio
from datetime import datetime
from typing import Any, Dict, Optional
from copy import deepcopy
from google.adk import Agent
from google.adk.apps import App
from google.adk.runners import Runner
from google.adk.sessions.in_memory_session_service import InMemorySessionService
from google.adk.tools.tool_context import ToolContext
from google.genai import types
from toolbox_adk import CredentialStrategy, ToolboxToolset, ToolboxTool
SYSTEM_PROMPT = """
You're a helpful hotel assistant. You handle hotel searching, booking and
cancellations. When the user searches for a hotel, mention it's name, id,
location and price tier. Always mention hotel ids while performing any
searches. This is very important for any operations. For any bookings or
cancellations, please provide the appropriate confirmation. Be sure to
update checkin or checkout dates if mentioned by the user.
Don't ask for confirmations from the user.
"""
# Pre processing
async def before_tool_callback(
tool: ToolboxTool, args: Dict[str, Any], tool_context: ToolContext
) -> Dict[str, Any]:
"""
Callback fired before a tool is executed.
Enforces business logic: Max stay duration is 14 days.
"""
tool_name = tool.name
print(f"POLICY CHECK: Intercepting '{tool_name}'")
if tool_name == "update-hotel" or (
"checkin_date" in args and "checkout_date" in args
):
start = datetime.fromisoformat(args["checkin_date"])
end = datetime.fromisoformat(args["checkout_date"])
duration = (end - start).days
if duration > 14:
print("BLOCKED: Stay too long")
return {"result": "Error: Maximum stay duration is 14 days."}
return None
# Post processing
async def after_tool_callback(
tool: ToolboxTool,
args: Dict[str, Any],
tool_context: ToolContext,
tool_response: Any,
):
"""
Callback fired after a tool execution.
Enriches response for successful bookings.
"""
if isinstance(tool_response, dict):
result = tool_response.get("result", "")
elif isinstance(tool_response, str):
result = tool_response
else:
return None
tool_name = tool.name
if isinstance(result, str) and "Error" not in result:
is_booking_tool = tool_name == "book-hotel"
if is_booking_tool:
loyalty_bonus = 500
enriched_result = f"Booking Confirmed!\n You earned {loyalty_bonus} Loyalty Points with this stay.\n\nSystem Details: {result}"
if isinstance(tool_response, dict):
modified_response = deepcopy(tool_response)
modified_response["result"] = enriched_result
return modified_response
else:
return enriched_result
return None
async def run_chat_turn(
runner: Runner, session_id: str, user_id: str, message_text: str
):
"""Executes a single chat turn and prints the interaction."""
print(f"\nUSER: '{message_text}'")
response_text = ""
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=types.Content(role="user", parts=[types.Part(text=message_text)]),
):
if event.content and event.content.parts:
for part in event.content.parts:
if part.text:
response_text += part.text
print(f"AI: {response_text}")
async def main():
toolset = ToolboxToolset(
server_url="http://127.0.0.1:5000",
toolset_name="my-toolset",
credentials=CredentialStrategy.toolbox_identity(),
)
tools = await toolset.get_tools()
root_agent = Agent(
name="root_agent",
model="gemini-2.5-flash",
instruction=SYSTEM_PROMPT,
tools=tools,
before_tool_callback=before_tool_callback,
after_tool_callback=after_tool_callback,
)
app = App(root_agent=root_agent, name="my_agent")
runner = Runner(app=app, session_service=InMemorySessionService())
session_id = "test-session"
user_id = "test-user"
await runner.session_service.create_session(
app_name=app.name, user_id=user_id, session_id=session_id
)
# First turn: Successful booking
await run_chat_turn(runner, session_id, user_id, "Book hotel with id 3.")
print("-" * 50)
# Second turn: Policy violation (stay > 14 days)
await run_chat_turn(
runner,
session_id,
user_id,
"Book a hotel with id 5 with checkin date 2025-01-18 and checkout date 2025-02-10",
)
await toolset.close()
if __name__ == "__main__":
asyncio.run(main())