234 lines
6.7 KiB
Python
234 lines
6.7 KiB
Python
"""
|
||
用于执行支持工具的 LangChain 代理的编排器。
|
||
"""
|
||
from datetime import datetime, timezone
from typing import List, Dict, Any, Optional

from sqlalchemy.orm import Session
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder

from app.db.models import ToolCall
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.tools.retriever import KnowledgeBaseRetriever
from app.tools.calculator import CalculatorTool
from app.config import Settings
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger; call sites in this module pass an event name plus keyword fields.
logger = get_logger(__name__)
|
||
|
||
|
||
class AgentOrchestrator:
    """
    Orchestrate a LangChain tool-calling agent with the project's custom tools.

    For each request this builds the tool list and prompt, runs the agent via
    ``AgentExecutor``, and persists every tool invocation as a ``ToolCall`` row.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the agent orchestrator.

        Args:
            db_session: Database session used to store and query ``ToolCall`` rows.
            model_manager: Model manager used to obtain the LLM.
            kb_manager: Knowledge-base manager handed to the retrieval tool.
            settings: Application settings (stored; not read elsewhere in this class).
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _create_tools(self, kb_id: Optional[int] = None) -> List:
        """
        Create the tools available to the agent.

        Args:
            kb_id: Optional knowledge-base ID. When given, a retrieval tool is
                added alongside the calculator.

        Returns:
            List: LangChain tool instances.
        """
        # The calculator is always available.
        tools = [CalculatorTool()]

        if kb_id is not None:
            # NOTE: the retriever is NOT bound to kb_id here. Presumably the
            # model supplies kb_id as a tool argument, which is why the system
            # prompt (see _create_agent_prompt) states the ID to use.
            tools.append(KnowledgeBaseRetriever(kb_manager=self.kb_manager))

        logger.info("tools_created", tool_count=len(tools), kb_id=kb_id)
        return tools

    def _create_agent_prompt(self, kb_id: Optional[int] = None) -> ChatPromptTemplate:
        """
        Create the agent prompt template.

        Args:
            kb_id: Optional knowledge-base ID. When given, the system message
                advertises the knowledge-base tool and tells the model which
                kb_id to pass to it; otherwise only the calculator is advertised
                (the retrieval tool is not registered in that case).

        Returns:
            ChatPromptTemplate: Prompt with an optional ``chat_history`` slot,
            an ``{input}`` human message and the ``agent_scratchpad`` slot.
        """
        system_lines = [
            "You are a helpful AI assistant with access to various tools.",
            "",
            "Available tools:",
            "- calculator: For mathematical calculations",
        ]
        if kb_id is not None:
            system_lines.append("- knowledge_base_search: For searching stored documents")

        system_lines += ["", "Always think step by step and use tools when appropriate."]

        if kb_id is not None:
            # Without this the model has no way to learn the kb_id it is told
            # to supply. The template is parsed as an f-string, so escape any
            # braces in the interpolated value.
            kb_text = str(kb_id).replace("{", "{{").replace("}", "}}")
            system_lines.append(
                "When using the knowledge_base_search tool, you must provide the "
                f"kb_id parameter. Use kb_id={kb_text} for this task."
            )

        system_message = "\n".join(system_lines) + "\n"

        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

    async def execute_agent(
        self,
        task: str,
        agent_id: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run the agent on a task with the configured tools.

        Args:
            task: Task description, passed to the agent as ``input``.
            agent_id: Unique ID for this execution; tool calls are stored under it.
            kb_id: Optional knowledge-base ID; enables the retrieval tool.
            llm_name: Optional LLM model name forwarded to ``get_llm``.

        Returns:
            Dict: On success ``{"agent_id", "output", "success": True}``; on
            failure ``{"agent_id", "output", "success": False, "error"}``.
            Exceptions are caught and reported in the result, never raised.
        """
        logger.info(
            "executing_agent",
            agent_id=agent_id,
            task=task[:100],
            kb_id=kb_id,
        )

        try:
            llm = self.model_manager.get_llm(llm_name)
            tools = self._create_tools(kb_id=kb_id)
            prompt = self._create_agent_prompt(kb_id=kb_id)
            agent = create_openai_tools_agent(llm, tools, prompt)

            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                verbose=True,
                max_iterations=10,
                handle_parsing_errors=True,
                # Required: otherwise the executor's output has no
                # "intermediate_steps" key and _log_tool_calls records nothing.
                return_intermediate_steps=True,
            )

            result = await agent_executor.ainvoke({
                "input": task,
                # Extra input key; the prompt template does not reference it
                # (kb_id is embedded in the system message instead).
                "kb_id": kb_id,
            })

            output = result.get("output", "")

            # Persist tool calls (best-effort; failures are logged, not raised).
            self._log_tool_calls(agent_id, result)

            logger.info("agent_execution_complete", agent_id=agent_id)
            return {
                "agent_id": agent_id,
                "output": output,
                "success": True,
            }

        except Exception as e:
            logger.error("agent_execution_failed", agent_id=agent_id, error=str(e))
            return {
                "agent_id": agent_id,
                "output": f"Agent execution failed: {str(e)}",
                "success": False,
                "error": str(e),
            }

    def _log_tool_calls(self, agent_id: str, result: Dict[str, Any]) -> None:
        """
        Persist the agent's tool invocations as ``ToolCall`` rows.

        Best-effort: any failure is logged and the session rolled back, but the
        exception is not propagated.

        Args:
            agent_id: Agent execution ID the rows are stored under.
            result: Executor output; its ``intermediate_steps`` entries are
                ``(action, observation)`` pairs.
        """
        try:
            # Treat a missing or None value as "no tool calls".
            intermediate_steps = result.get("intermediate_steps") or []

            for action, observation in intermediate_steps:
                self.db.add(
                    ToolCall(
                        agent_id=agent_id,
                        tool_name=action.tool,
                        call_input={"tool_input": action.tool_input},
                        call_output={"observation": str(observation)},
                        # Naive UTC, the same value datetime.utcnow() returned
                        # (utcnow is deprecated since Python 3.12).
                        created_at=datetime.now(timezone.utc).replace(tzinfo=None),
                    )
                )

            self.db.commit()
            logger.info(
                "tool_calls_logged",
                agent_id=agent_id,
                call_count=len(intermediate_steps),
            )

        except Exception as e:
            logger.error("tool_call_logging_failed", agent_id=agent_id, error=str(e))
            self.db.rollback()

    def get_tool_calls(
        self,
        agent_id: str,
        limit: int = 50,
        offset: int = 0,
    ) -> List[ToolCall]:
        """
        Fetch the tool calls recorded for an agent execution, oldest first.

        Args:
            agent_id: Agent execution ID.
            limit: Maximum number of records to return.
            offset: Number of records to skip (pagination).

        Returns:
            List[ToolCall]: Matching records ordered by ``created_at`` ascending.
        """
        return (
            self.db.query(ToolCall)
            .filter(ToolCall.agent_id == agent_id)
            .order_by(ToolCall.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
|