# langchain-learning-kit/app/services/agent_orchestrator.py

"""
Orchestrator for running tool-enabled LangChain agents.
"""
from typing import List, Dict, Any, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from app.db.models import ToolCall
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.tools.retriever import KnowledgeBaseRetriever
from app.tools.calculator import CalculatorTool
from app.config import Settings
from app.utils.logger import get_logger
# Module-level logger; call sites in this file pass an event name plus keyword fields.
logger = get_logger(__name__)
class AgentOrchestrator:
    """
    Orchestrates LangChain agents equipped with this project's custom tools.

    An instance holds the database session, model manager, knowledge-base
    manager and settings. ``execute_agent`` builds a fresh agent for every
    call and records the tool calls it made as ``ToolCall`` rows.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the agent orchestrator.

        Args:
            db_session: Database session used to store and query tool calls.
            model_manager: Model manager used to obtain the LLM.
            kb_manager: Knowledge-base manager handed to the retrieval tool.
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _create_tools(self, kb_id: Optional[int] = None) -> List:
        """
        Build the tool list for one agent run.

        Args:
            kb_id: Optional knowledge-base ID; when given, the retrieval tool
                is offered in addition to the calculator.

        Returns:
            List: LangChain tools available to the agent.
        """
        # The calculator is always available.
        tools = [CalculatorTool()]

        # The retrieval tool is only offered when a knowledge base is specified.
        if kb_id is not None:
            # NOTE(review): kb_id is not bound to the retriever here, so the
            # model has to supply it as a tool argument; the system prompt
            # built by _create_agent_prompt tells it which value to use.
            tools.append(KnowledgeBaseRetriever(kb_manager=self.kb_manager))

        logger.info("tools_created", tool_count=len(tools), kb_id=kb_id)
        return tools

    def _create_agent_prompt(self, kb_id: Optional[int] = None) -> ChatPromptTemplate:
        """
        Build the chat prompt template for the agent.

        Args:
            kb_id: Optional knowledge-base ID. When given, it is stated in the
                system message so the model can pass it to the
                knowledge_base_search tool. When omitted, the system message
                is unchanged.

        Returns:
            ChatPromptTemplate: System message, optional ``chat_history``,
            the human ``{input}`` and the ``agent_scratchpad`` placeholder.
        """
        system_message = (
            "You are a helpful AI assistant with access to various tools.\n"
            "Available tools:\n"
            "- calculator: For mathematical calculations\n"
            "- knowledge_base_search: For searching stored documents "
            "(if knowledge base is provided)\n"
            "Always think step by step and use tools when appropriate.\n"
            "When using the knowledge_base_search tool, "
            "you must provide the kb_id parameter.\n"
        )
        if kb_id is not None:
            # The prompt has no kb_id placeholder, so without this line the
            # model is asked for a kb_id it was never told.
            system_message += f"The knowledge base ID for this task is {kb_id}.\n"

        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

    async def execute_agent(
        self,
        task: str,
        agent_id: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run an agent task with tools and log the tool calls it made.

        Args:
            task: Task description sent to the agent as its input.
            agent_id: Unique ID of this agent execution.
            kb_id: Optional knowledge-base ID.
            llm_name: Optional name of the LLM to use.

        Returns:
            Dict: ``agent_id``, ``output`` and ``success``. When an
            ``Exception`` is raised during execution it is caught, and the
            dict has ``success=False``, an ``error`` string and a failure
            message in ``output``.
        """
        logger.info(
            "executing_agent",
            agent_id=agent_id,
            task=task[:100],  # only the first 100 characters are logged
            kb_id=kb_id,
        )
        try:
            llm = self.model_manager.get_llm(llm_name)
            tools = self._create_tools(kb_id=kb_id)
            prompt = self._create_agent_prompt(kb_id=kb_id)
            agent = create_openai_tools_agent(llm, tools, prompt)

            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                verbose=True,
                max_iterations=10,
                handle_parsing_errors=True,
                # Without this flag the result has no "intermediate_steps"
                # key and _log_tool_calls would never record anything.
                return_intermediate_steps=True,
            )

            result = await agent_executor.ainvoke({
                "input": task,
                # Kept from the original call; the prompt has no kb_id
                # placeholder, so this entry is not rendered into the prompt.
                "kb_id": kb_id,
            })
            output = result.get("output", "")

            # Persist the tool calls made during this run.
            self._log_tool_calls(agent_id, result)

            logger.info("agent_execution_complete", agent_id=agent_id)
            return {
                "agent_id": agent_id,
                "output": output,
                "success": True,
            }
        except Exception as e:
            logger.error("agent_execution_failed", agent_id=agent_id, error=str(e))
            return {
                "agent_id": agent_id,
                "output": f"Agent execution failed: {str(e)}",
                "success": False,
                "error": str(e),
            }

    def _log_tool_calls(self, agent_id: str, result: Dict[str, Any]) -> None:
        """
        Store the tool calls of an agent run in the database.

        A failure is logged and the session is rolled back; the exception is
        not re-raised.

        Args:
            agent_id: ID of the agent execution.
            result: Agent execution result; its ``intermediate_steps`` entry
                is read as a sequence of ``(action, observation)`` pairs.
        """
        try:
            intermediate_steps = result.get("intermediate_steps", [])
            for action, observation in intermediate_steps:
                tool_call = ToolCall(
                    agent_id=agent_id,
                    tool_name=action.tool,
                    call_input={"tool_input": action.tool_input},
                    call_output={"observation": str(observation)},
                    created_at=datetime.utcnow(),
                )
                self.db.add(tool_call)

            # One commit for all rows of this run.
            self.db.commit()
            logger.info(
                "tool_calls_logged",
                agent_id=agent_id,
                call_count=len(intermediate_steps),
            )
        except Exception as e:
            logger.error("tool_call_logging_failed", agent_id=agent_id, error=str(e))
            self.db.rollback()

    def get_tool_calls(
        self,
        agent_id: str,
        limit: int = 50,
        offset: int = 0,
    ) -> List[ToolCall]:
        """
        Fetch the recorded tool calls of an agent execution.

        Args:
            agent_id: ID of the agent execution.
            limit: Maximum number of records to return.
            offset: Pagination offset.

        Returns:
            List[ToolCall]: Tool-call records ordered by ``created_at``
            ascending.
        """
        tool_calls = (
            self.db.query(ToolCall)
            .filter(ToolCall.agent_id == agent_id)
            .order_by(ToolCall.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
        return tool_calls