This commit is contained in:
Twisha Bansal
2026-01-28 18:20:40 +05:30
parent f169874e53
commit 49cb2f39f7
2 changed files with 26 additions and 24 deletions

View File

@@ -1,21 +1,22 @@
# Copyright 2026 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
from pathlib import Path
import asyncio
import os
from pathlib import Path
import pytest
ORCH_NAME = os.environ.get("ORCH_NAME")
module_path = f"python.{ORCH_NAME}.agent"
@@ -29,7 +30,7 @@ def golden_keywords():
if not golden_file_path.exists():
pytest.fail(f"Golden file not found: {golden_file_path}")
try:
with open(golden_file_path, 'r') as f:
with open(golden_file_path, "r") as f:
return [line.strip() for line in f.readlines() if line.strip()]
except Exception as e:
pytest.fail(f"Could not read golden.txt: {e}")

View File

@@ -1,11 +1,10 @@
import asyncio
from toolbox_langchain import ToolboxClient
from langchain_google_vertexai import ChatVertexAI
from langchain_core.messages import ToolMessage, messages_to_dict
from langchain.agents import create_agent
from langchain.agents.middleware import (
wrap_tool_call,
)
from langchain.agents.middleware import wrap_tool_call
from langchain_core.messages import ToolMessage, messages_to_dict
from langchain_google_vertexai import ChatVertexAI
from toolbox_langchain import ToolboxClient
system_prompt = """
You're a helpful hotel assistant. You handle hotel searching, booking and
@@ -17,6 +16,7 @@ system_prompt = """
Don't ask for confirmations from the user.
"""
# Pre processing
@wrap_tool_call
async def enforce_business_rules(request, handler):
@@ -32,14 +32,15 @@ async def enforce_business_rules(request, handler):
if name == "book-hotel":
if "duration_days" in args and int(args["duration_days"]) > 14:
print("BLOCKED: Stay too long")
return ToolMessage(
content="Error: Maximum stay duration is 14 days.",
tool_call_id=tool_call["id"]
)
print("BLOCKED: Stay too long")
return ToolMessage(
content="Error: Maximum stay duration is 14 days.",
tool_call_id=tool_call["id"],
)
return await handler(request)
# Post processing
@wrap_tool_call
async def enrich_response(request, handler):
@@ -69,19 +70,19 @@ async def main():
system_prompt=system_prompt,
model=model,
tools=tools,
middleware=[
enforce_business_rules,
enrich_response
],
middleware=[enforce_business_rules, enrich_response],
)
user_input = "Book hotel with id 3."
response = await agent.ainvoke({"messages": [{"role": "user", "content": user_input}]})
response = await agent.ainvoke(
{"messages": [{"role": "user", "content": user_input}]}
)
print("-" * 50)
print("Final Client Response:")
last_ai_msg = response["messages"][-1].content
print(f"AI: {last_ai_msg}")
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())