"""
Agent execution API endpoints.
"""
from typing import List, Optional
|
||
import uuid
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db.session import get_db
|
||
from app.services.agent_orchestrator import AgentOrchestrator
|
||
from app.services.model_manager import ModelManager
|
||
from app.services.kb_manager import KnowledgeBaseManager
|
||
from app.services.job_manager import AsyncJobManager
|
||
from app.config import get_settings
|
||
from app.utils.logger import get_logger
|
||
|
||
# Router that groups every endpoint of this module under the "/agent" prefix.
router = APIRouter(tags=["agent"], prefix="/agent")
# Module-level logger; called with structured keyword fields below.
logger = get_logger(__name__)
# Pydantic schemas
class AgentExecuteRequest(BaseModel):
    """Request body for running one agent task."""

    # Required free-text instruction handed to the agent.
    task: str = Field(..., description="Task description for the agent")
    # When provided, the agent additionally gets a knowledge-base retriever tool.
    kb_id: Optional[int] = Field(default=None, description="Knowledge base ID for retriever tool")
    # Optional LLM model name; presumably the orchestrator picks a default when None — confirm.
    llm_name: Optional[str] = Field(default=None, description="LLM model name")
class AgentExecuteResponse(BaseModel):
    """Outcome of an agent run as returned by the execute endpoint."""

    agent_id: str  # identifier generated for this run
    output: str  # agent output; the error path in this module sends ""
    success: bool  # False when execution raised
    error: Optional[str] = None  # failure message, set only on failure
class ToolCallResponse(BaseModel):
    """One recorded tool invocation made during an agent run."""

    id: int
    agent_id: str  # the run this call belongs to
    tool_name: str
    call_input: dict  # arguments the tool was called with
    call_output: dict  # what the tool returned
    created_at: str  # get_agent_logs fills this with an ISO-8601 string, or "" if missing

    class Config:
        # Allow population from object attributes (e.g. ORM rows), not only dicts.
        from_attributes = True
class AgentLogsResponse(BaseModel):
    """Page of tool-call log entries for a single agent run."""

    agent_id: str
    tool_calls: List[ToolCallResponse]
    # NOTE(review): get_agent_logs sets this to len(tool_calls), i.e. the size
    # of the returned page rather than the run's overall number of calls.
    total_calls: int
# Dependencies
def get_agent_orchestrator(
    db: Session = Depends(get_db),
) -> AgentOrchestrator:
    """Build an AgentOrchestrator around the session supplied by ``get_db``.

    A new ModelManager, AsyncJobManager and KnowledgeBaseManager are created
    on every call and wired together with the application settings.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    # NOTE(review): a fresh AsyncJobManager per call — if it keeps in-memory
    # job state, that state is not shared between requests; confirm intended.
    knowledge_bases = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return AgentOrchestrator(db, models, knowledge_bases, cfg)
@router.post("/execute", response_model=AgentExecuteResponse)
async def execute_agent(
    request: AgentExecuteRequest,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Execute an agent task with tools.

    The agent has access to:
    - a calculator tool, for arithmetic
    - a knowledge-base retriever tool, for document search (only if ``kb_id`` is given)

    Failures are not turned into HTTP error statuses: they are logged and
    reported in the response body with ``success=False`` and an ``error`` text.

    Args:
        request: The agent execution request.
        orchestrator: AgentOrchestrator instance (injected).

    Returns:
        AgentExecuteResponse with the output and status. On failure the
        response carries the same ``agent_id`` that was logged and passed to
        the orchestrator, so the failed run can still be correlated with logs
        and queried via ``GET /agent/logs/{agent_id}``.
    """
    # Create the run ID before anything can fail; previously the error path
    # generated a second, unrelated UUID, breaking correlation for the client.
    agent_id = str(uuid.uuid4())

    try:
        logger.info(
            "agent_execute_request",
            agent_id=agent_id,
            task=request.task[:100],  # truncated to keep log lines bounded
            kb_id=request.kb_id,
        )

        result = await orchestrator.execute_agent(
            task=request.task,
            agent_id=agent_id,
            kb_id=request.kb_id,
            llm_name=request.llm_name,
        )

        return AgentExecuteResponse(**result)

    except Exception as e:
        # Deliberate top-level boundary: report the failure in the body.
        logger.error("agent_execute_api_failed", agent_id=agent_id, error=str(e))
        return AgentExecuteResponse(
            agent_id=agent_id,
            output="",
            success=False,
            error=f"Agent execution failed: {str(e)}",
        )
@router.get("/logs/{agent_id}", response_model=AgentLogsResponse)
def get_agent_logs(
    agent_id: str,
    limit: int = 50,
    offset: int = 0,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Get the tool-call log of an agent run, one page at a time.

    Args:
        agent_id: ID of the agent run.
        limit: Maximum number of records to return (default 50, must be >= 0).
        offset: Number of records to skip for pagination (default 0, must be >= 0).
        orchestrator: AgentOrchestrator instance (injected).

    Returns:
        AgentLogsResponse with the requested page of tool calls. Note that
        ``total_calls`` is the number of calls in this page, not the overall
        number recorded for the run.

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative;
            500 if fetching or converting the log fails.
    """
    # Validate paging before the try block, whose broad handler would
    # otherwise turn this client error into a 500. Negative values are
    # presumably forwarded to a database LIMIT/OFFSET, where their meaning
    # is backend-specific.
    if limit < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit must be >= 0",
        )
    if offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="offset must be >= 0",
        )

    try:
        tool_calls = orchestrator.get_tool_calls(
            agent_id=agent_id,
            limit=limit,
            offset=offset,
        )

        return AgentLogsResponse(
            agent_id=agent_id,
            tool_calls=[
                ToolCallResponse(
                    id=tc.id,
                    agent_id=tc.agent_id,
                    tool_name=tc.tool_name,
                    # Stored payloads may be NULL; the schema requires dicts.
                    call_input=tc.call_input or {},
                    call_output=tc.call_output or {},
                    created_at=tc.created_at.isoformat() if tc.created_at else "",
                )
                for tc in tool_calls
            ],
            total_calls=len(tool_calls),
        )

    except Exception as e:
        logger.error("agent_logs_failed", agent_id=agent_id, error=str(e))
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get agent logs: {str(e)}",
        ) from e