168 lines
4.7 KiB
Python
168 lines
4.7 KiB
Python
|
|
"""
Agent execution API endpoints.
"""
|
|||
|
|
import uuid
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session

from app.db.session import get_db
from app.services.agent_orchestrator import AgentOrchestrator
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.services.job_manager import AsyncJobManager
from app.config import get_settings
from app.utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# Module-scoped logger; call sites pass structured key/value context as kwargs.
logger = get_logger(__name__)

# Every route in this module is mounted under "/agent" and grouped under the
# "agent" tag in the generated OpenAPI docs.
router = APIRouter(
    prefix="/agent",
    tags=["agent"],
)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Pydantic schemas


class AgentExecuteRequest(BaseModel):
    """Request body for the agent execution endpoint."""

    # Required: the task the agent should carry out.
    task: str = Field(
        ...,
        description="Task description for the agent",
    )
    # Optional: knowledge base backing the retriever tool.
    kb_id: Optional[int] = Field(
        default=None,
        description="Knowledge base ID for retriever tool",
    )
    # Optional: name of the LLM to run the agent with.
    llm_name: Optional[str] = Field(
        default=None,
        description="LLM model name",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentExecuteResponse(BaseModel):
    """Response body returned by the agent execution endpoint."""

    # Identifier of the agent run (a UUID string generated by the endpoint).
    agent_id: str
    # Text output produced by the agent run.
    output: str
    # Whether the run completed successfully.
    success: bool
    # Error description, if any; omitted/None by default.
    error: Optional[str] = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ToolCallResponse(BaseModel):
    """Response schema for a single recorded tool call of an agent run."""

    # Allow building instances from ORM objects via attribute access.
    # Uses the pydantic v2 ``model_config`` form; the class-based ``Config``
    # it replaces is deprecated in v2 and emits a warning at import time.
    model_config = ConfigDict(from_attributes=True)

    id: int
    agent_id: str
    tool_name: str
    call_input: dict
    call_output: dict
    # Serialized timestamp string (the logs endpoint fills it via isoformat()).
    created_at: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentLogsResponse(BaseModel):
    """Response body listing the recorded tool calls of one agent run."""

    # Identifier of the agent run the logs belong to.
    agent_id: str
    # Tool calls returned for this request.
    tool_calls: List[ToolCallResponse]
    # Count reported alongside the tool calls.
    total_calls: int
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Dependencies


def get_agent_orchestrator(
    db: Session = Depends(get_db),
) -> AgentOrchestrator:
    """Build an AgentOrchestrator wired to the request's DB session.

    A new ModelManager, AsyncJobManager and KnowledgeBaseManager are created
    on every call; all of them share the same session and settings object.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    # The job manager is constructed (after the model manager) and handed
    # straight to the knowledge-base manager.
    kb = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return AgentOrchestrator(db, models, kb, cfg)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/execute", response_model=AgentExecuteResponse)
async def execute_agent(
    request: AgentExecuteRequest,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Execute an agent task with tool access.

    The agent has access to:
    - a calculator tool for mathematical calculations
    - a knowledge-base retriever tool for document search (when ``kb_id`` is given)

    Failures are not surfaced as HTTP errors: any exception is logged and
    reported in the response body with ``success=False``.

    Args:
        request: The agent execution request.
        orchestrator: Injected AgentOrchestrator instance.

    Returns:
        AgentExecuteResponse with the output and status. On failure the
        response carries the same ``agent_id`` that was assigned to the run,
        so the caller can correlate it and query ``/agent/logs/{agent_id}``
        for any tool calls recorded before the failure.
    """
    # Generate the run ID before the try block so the failure path reports the
    # ID actually used for this run instead of an unrelated fresh UUID.
    agent_id = str(uuid.uuid4())

    try:
        logger.info(
            "agent_execute_request",
            agent_id=agent_id,
            task=request.task[:100],
            kb_id=request.kb_id,
        )

        result = await orchestrator.execute_agent(
            task=request.task,
            agent_id=agent_id,
            kb_id=request.kb_id,
            llm_name=request.llm_name,
        )

        return AgentExecuteResponse(**result)

    except Exception as e:
        logger.error("agent_execute_api_failed", agent_id=agent_id, error=str(e))
        return AgentExecuteResponse(
            agent_id=agent_id,
            output="",
            success=False,
            error=f"Agent execution failed: {str(e)}",
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/logs/{agent_id}", response_model=AgentLogsResponse)
def get_agent_logs(
    agent_id: str,
    limit: int = 50,
    offset: int = 0,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Fetch the tool-call logs of an agent execution.

    Args:
        agent_id: Agent execution ID.
        limit: Maximum number of records to return (default 50); must be >= 0.
        offset: Pagination offset (default 0); must be >= 0.
        orchestrator: Injected AgentOrchestrator instance.

    Returns:
        The agent execution logs including its tool calls.

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative; 500 if
            fetching or serializing the tool calls fails.
    """
    # Validate paging inputs before the try block so this 400 is not rewritten
    # into a 500 by the generic handler below. Negative values would otherwise
    # reach the data layer, where their effect is backend-dependent.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative",
        )

    try:
        tool_calls = orchestrator.get_tool_calls(
            agent_id=agent_id,
            limit=limit,
            offset=offset,
        )

        return AgentLogsResponse(
            agent_id=agent_id,
            tool_calls=[
                ToolCallResponse(
                    id=tc.id,
                    agent_id=tc.agent_id,
                    tool_name=tc.tool_name,
                    # Stored payloads may be NULL; the schema requires dicts.
                    call_input=tc.call_input or {},
                    call_output=tc.call_output or {},
                    created_at=tc.created_at.isoformat() if tc.created_at else "",
                )
                for tc in tool_calls
            ],
            # NOTE(review): this is the number of calls on the returned page,
            # not the total recorded for the agent - confirm that is intended.
            total_calls=len(tool_calls),
        )

    except Exception as e:
        logger.error("agent_logs_failed", agent_id=agent_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get agent logs: {str(e)}",
        ) from e
|