2025-10-02 17:22:19 +08:00
|
|
|
|
"""
|
|
|
|
|
|
用于执行支持工具的 LangChain 代理的编排器。
|
|
|
|
|
|
"""
|
|
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
|
|
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
2025-10-02 18:10:53 +08:00
|
|
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
2025-10-02 17:22:19 +08:00
|
|
|
|
|
|
|
|
|
|
from app.db.models import ToolCall
|
|
|
|
|
|
from app.services.model_manager import ModelManager
|
|
|
|
|
|
from app.services.kb_manager import KnowledgeBaseManager
|
|
|
|
|
|
from app.tools.retriever import KnowledgeBaseRetriever
|
|
|
|
|
|
from app.tools.calculator import CalculatorTool
|
|
|
|
|
|
from app.config import Settings
|
|
|
|
|
|
from app.utils.logger import get_logger
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger. Call sites in this file pass an event name followed by
# structured keyword fields (e.g. agent_id=..., kb_id=...).
logger = get_logger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AgentOrchestrator:
    """
    Orchestrate a LangChain tool-calling agent with the project's custom tools.

    Each call to ``execute_agent`` builds a fresh tool set (calculator, plus an
    optional knowledge-base retriever), a chat prompt and an ``AgentExecutor``.
    It runs the task and persists every tool invocation as a ``ToolCall`` row.
    """

    def __init__(
        self,
        db_session: Session,
        model_manager: ModelManager,
        kb_manager: KnowledgeBaseManager,
        settings: Settings,
    ):
        """
        Initialize the agent orchestrator.

        Args:
            db_session: Database session used to store and query ``ToolCall`` rows.
            model_manager: Model manager used to resolve the LLM by name.
            kb_manager: Knowledge-base manager handed to the retriever tool.
            settings: Application settings, stored on the instance.
        """
        self.db = db_session
        self.model_manager = model_manager
        self.kb_manager = kb_manager
        self.settings = settings

    def _create_tools(self, kb_id: Optional[int] = None) -> List:
        """
        Create the tools available to the agent.

        Args:
            kb_id: Optional knowledge-base ID. When given, a retriever tool is
                added. The ID itself is not bound to the tool; it is conveyed
                to the model through the system prompt built by
                ``_create_agent_prompt``.

        Returns:
            List: LangChain tools.
        """
        # The calculator is always available.
        tools = [CalculatorTool()]

        # Only expose the retriever when a knowledge base was specified.
        if kb_id is not None:
            tools.append(KnowledgeBaseRetriever(kb_manager=self.kb_manager))

        logger.info("tools_created", tool_count=len(tools), kb_id=kb_id)
        return tools

    def _create_agent_prompt(self, kb_id: Optional[int] = None) -> ChatPromptTemplate:
        """
        Create the agent prompt template.

        The advertised tool list matches the tools actually created for the
        run, so the model is never told about a tool it does not have. When a
        knowledge base is in play, the system message contains a ``{kb_id}``
        template slot. The caller must therefore supply ``kb_id`` in the
        invocation input so the model learns the concrete ID to pass.

        Args:
            kb_id: Optional knowledge-base ID for this run. Only its presence
                (``None`` or not) is used here; the value itself is filled in
                from the invocation input.

        Returns:
            ChatPromptTemplate: Prompt with a system message, an optional
            ``chat_history`` placeholder, the human ``{input}`` and the
            ``agent_scratchpad`` placeholder.
        """
        has_kb = kb_id is not None

        tool_lines = ["- calculator: For mathematical calculations"]
        if has_kb:
            tool_lines.append("- knowledge_base_search: For searching stored documents")

        system_message = (
            "You are a helpful AI assistant with access to various tools.\n"
            "\n"
            "Available tools:\n"
            + "\n".join(tool_lines)
            + "\n"
            "\n"
            "Always think step by step and use tools when appropriate.\n"
        )
        if has_kb:
            # "{kb_id}" is a prompt template variable (not an f-string), filled
            # from the "kb_id" key passed to AgentExecutor.ainvoke.
            system_message += (
                "When using the knowledge_base_search tool, you must provide "
                "the kb_id parameter. The knowledge base for this task has "
                "kb_id={kb_id}.\n"
            )

        return ChatPromptTemplate.from_messages([
            ("system", system_message),
            MessagesPlaceholder(variable_name="chat_history", optional=True),
            ("human", "{input}"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ])

    async def execute_agent(
        self,
        task: str,
        agent_id: str,
        kb_id: Optional[int] = None,
        llm_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Run an agent task with tools.

        Any exception raised while building or running the agent is caught,
        logged, and reported in the returned dict rather than re-raised.

        Args:
            task: Task description given to the agent as the human input.
            agent_id: Unique ID of this agent execution; used to tag ToolCall rows.
            kb_id: Optional knowledge-base ID enabling the retriever tool.
            llm_name: Optional LLM model name passed to the model manager.

        Returns:
            Dict: ``{"agent_id", "output", "success"}`` on success; on failure
            ``success`` is False and an ``error`` message is included.
        """
        logger.info(
            "executing_agent",
            agent_id=agent_id,
            task=task[:100],
            kb_id=kb_id,
        )

        try:
            llm = self.model_manager.get_llm(llm_name)
            tools = self._create_tools(kb_id=kb_id)
            prompt = self._create_agent_prompt(kb_id=kb_id)
            agent = create_openai_tools_agent(llm, tools, prompt)

            agent_executor = AgentExecutor(
                agent=agent,
                tools=tools,
                verbose=True,
                max_iterations=10,
                handle_parsing_errors=True,
                # Required for the result to contain "intermediate_steps";
                # without it _log_tool_calls never sees any tool invocations.
                return_intermediate_steps=True,
            )

            result = await agent_executor.ainvoke({
                "input": task,
                # Fills the {kb_id} slot of the system prompt (present only
                # when a knowledge base is set; ignored otherwise).
                "kb_id": kb_id,
            })

            output = result.get("output", "")

            # Persist the tool invocations (best-effort; failures are logged).
            self._log_tool_calls(agent_id, result)

            logger.info("agent_execution_complete", agent_id=agent_id)

            return {
                "agent_id": agent_id,
                "output": output,
                "success": True,
            }

        except Exception as e:
            logger.error("agent_execution_failed", agent_id=agent_id, error=str(e))
            return {
                "agent_id": agent_id,
                "output": f"Agent execution failed: {str(e)}",
                "success": False,
                "error": str(e),
            }

    def _log_tool_calls(self, agent_id: str, result: Dict[str, Any]) -> None:
        """
        Persist the tool invocations from an agent run as ``ToolCall`` rows.

        This is best-effort. On any error the failure is logged and the
        session is rolled back; the exception is not propagated.

        Args:
            agent_id: Agent execution ID stored on every row.
            result: AgentExecutor result. Its ``"intermediate_steps"`` entry,
                if present, is a sequence of ``(action, observation)`` pairs.
        """
        try:
            intermediate_steps = result.get("intermediate_steps", [])

            for action, observation in intermediate_steps:
                tool_call = ToolCall(
                    agent_id=agent_id,
                    tool_name=action.tool,
                    call_input={"tool_input": action.tool_input},
                    call_output={"observation": str(observation)},
                    # NOTE(review): utcnow() is naive UTC and deprecated in
                    # Python 3.12+; kept to match existing stored rows.
                    created_at=datetime.utcnow(),
                )
                self.db.add(tool_call)

            # One commit for the whole batch of rows.
            self.db.commit()
            logger.info(
                "tool_calls_logged",
                agent_id=agent_id,
                call_count=len(intermediate_steps),
            )

        except Exception as e:
            logger.error("tool_call_logging_failed", agent_id=agent_id, error=str(e))
            self.db.rollback()

    def get_tool_calls(
        self,
        agent_id: str,
        limit: int = 50,
        offset: int = 0,
    ) -> List[ToolCall]:
        """
        Fetch the tool calls recorded for an agent execution.

        Args:
            agent_id: Agent execution ID.
            limit: Maximum number of records to return.
            offset: Number of records to skip (for pagination).

        Returns:
            List[ToolCall]: Matching records, oldest first (by ``created_at``).
        """
        return (
            self.db.query(ToolCall)
            .filter(ToolCall.agent_id == agent_id)
            .order_by(ToolCall.created_at.asc())
            .offset(offset)
            .limit(limit)
            .all()
        )
|