"""
Orchestrator for running tool-enabled LangChain agents.
"""
from typing import List, Dict, Any, Optional
from datetime import datetime

from sqlalchemy.orm import Session
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from app.db.models import ToolCall
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.tools.retriever import KnowledgeBaseRetriever
from app.tools.calculator import CalculatorTool
from app.config import Settings
from app.utils.logger import get_logger

logger = get_logger(__name__)


class AgentOrchestrator:
    """
    Orchestrates LangChain agents that use the project's custom tools.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the agent orchestrator.

        Args:
            db_session: Database session used to persist tool-call records.
            model_manager: Model manager used to obtain the LLM.
            kb_manager: Knowledge-base manager handed to the retriever tool.
            settings: Application settings.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _create_tools(self, kb_id: Optional[int] = None) -> List:
        """
        Build the list of tools available to the agent.

        Args:
            kb_id: Optional knowledge-base ID; when given, the retrieval tool
                is added.

        Returns:
            List: LangChain tools (calculator always, retriever if kb_id set).
        """
        # The calculator is always available.
        tools = [CalculatorTool()]

        if kb_id is not None:
            # NOTE: the retriever is NOT bound to kb_id here. The agent learns
            # the id from the system prompt (see _create_agent_prompt) and is
            # expected to pass it as an argument when it calls the tool.
            tools.append(KnowledgeBaseRetriever(kb_manager=self.kb_manager))

        logger.info("tools_created", tool_count=len(tools), kb_id=kb_id)
        return tools

    def _create_agent_prompt(self, kb_id: Optional[int] = None) -> ChatPromptTemplate:
        """
        Build the agent's chat prompt template.

        Args:
            kb_id: Optional knowledge-base ID. When given, the system message
                gains a ``{kb_id}`` template variable so the model knows which
                knowledge base to search; the value is filled from the
                ``"kb_id"`` key of the agent's invocation input.

        Returns:
            ChatPromptTemplate: System message, optional chat history, the
            human ``{input}``, and the agent scratchpad.
        """
        system_message = """You are a helpful AI assistant with access to various tools.

Available tools:
- calculator: For mathematical calculations
- knowledge_base_search: For searching stored documents (if knowledge base is provided)

Always think step by step and use tools when appropriate.
When using the knowledge_base_search tool, you must provide the kb_id parameter.
"""
        if kb_id is not None:
            # Without this line the model is told to supply kb_id but is never
            # told its value. "{kb_id}" is a template variable, not literal text.
            system_message += "The knowledge base ID for this task is {kb_id}.\n"

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_message),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

        return prompt

    async def execute_agent(
        self,
        task: str,
        agent_id: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run an agent task with tools and record the tool calls it made.

        Args:
            task: Task description given to the agent as its input.
            agent_id: Unique ID of this agent execution.
            kb_id: Optional knowledge-base ID enabling the retrieval tool.
            llm_name: Optional LLM model name (passed to the model manager).

        Returns:
            Dict: ``{"agent_id", "output", "success"}`` on success; on any
            exception ``success`` is False and an ``"error"`` key is added.
            Tool calls are persisted separately (see ``get_tool_calls``).
        """
        logger.info(
            "executing_agent",
            agent_id=agent_id,
            task=task[:100],
            kb_id=kb_id,
        )

        try:
            llm = self.model_manager.get_llm(llm_name)
            tools = self._create_tools(kb_id=kb_id)
            prompt = self._create_agent_prompt(kb_id=kb_id)
            agent = create_openai_tools_agent(llm, tools, prompt)

            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                verbose=True,
                max_iterations=10,
                handle_parsing_errors=True,
                # Required: without this the result has no "intermediate_steps"
                # key and _log_tool_calls would never record anything.
                return_intermediate_steps=True,
            )

            result = await agent_executor.ainvoke({
                "input": task,
                # Fills the "{kb_id}" prompt variable when a KB is selected.
                "kb_id": kb_id,
            })

            output = result.get("output", "")

            self._log_tool_calls(agent_id, result)

            logger.info("agent_execution_complete", agent_id=agent_id)

            return {
                "agent_id": agent_id,
                "output": output,
                "success": True,
            }

        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of
            # propagating it.
            logger.error("agent_execution_failed", agent_id=agent_id, error=str(e))
            return {
                "agent_id": agent_id,
                "output": f"Agent execution failed: {str(e)}",
                "success": False,
                "error": str(e),
            }

    def _log_tool_calls(self, agent_id: str, result: Dict[str, Any]) -> None:
        """
        Persist the agent's tool calls to the database (best effort).

        Failures are logged and the session is rolled back; they are not
        re-raised, so a logging problem never fails the agent run.

        Args:
            agent_id: Agent execution ID the records belong to.
            result: AgentExecutor result; its ``"intermediate_steps"`` entry
                is a sequence of ``(action, observation)`` pairs.
        """
        try:
            intermediate_steps = result.get("intermediate_steps", [])

            for action, observation in intermediate_steps:
                tool_call = ToolCall(
                    agent_id=agent_id,
                    tool_name=action.tool,
                    call_input={"tool_input": action.tool_input},
                    call_output={"observation": str(observation)},
                    # NOTE(review): naive UTC timestamp; presumably matches the
                    # ToolCall column type — confirm before switching to
                    # timezone-aware datetimes.
                    created_at=datetime.utcnow(),
                )
                self.db.add(tool_call)

            self.db.commit()

            logger.info(
                "tool_calls_logged",
                agent_id=agent_id,
                call_count=len(intermediate_steps),
            )

        except Exception as e:
            logger.error("tool_call_logging_failed", agent_id=agent_id, error=str(e))
            self.db.rollback()

    def get_tool_calls(
        self,
        agent_id: str,
        limit: int = 50,
        offset: int = 0,
    ) -> List[ToolCall]:
        """
        Fetch the recorded tool calls for an agent execution.

        Args:
            agent_id: Agent execution ID.
            limit: Maximum number of records to return.
            offset: Pagination offset.

        Returns:
            List[ToolCall]: Tool-call records, oldest first.
        """
        tool_calls = (
            self.db.query(ToolCall)
            .filter(ToolCall.agent_id == agent_id)
            .order_by(ToolCall.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )

        return tool_calls