langchain-learning-kit/app/api/agent.py

168 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
智能体执行 API 端点。
"""
import uuid
from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session

from app.config import get_settings
from app.db.session import get_db
from app.services.agent_orchestrator import AgentOrchestrator
from app.services.job_manager import AsyncJobManager
from app.services.kb_manager import KnowledgeBaseManager
from app.services.model_manager import ModelManager
from app.utils.logger import get_logger
# All endpoints in this module hang off "/agent" and share the "agent" tag.
router = APIRouter(tags=["agent"], prefix="/agent")
# Module logger; it is called with keyword fields (e.g. agent_id=..., error=...).
logger = get_logger(__name__)
# Pydantic schemas
class AgentExecuteRequest(BaseModel):
    """Request body for ``POST /agent/execute``.

    Attributes:
        task: Free-text task description handed to the agent (required).
        kb_id: Optional knowledge-base ID; when given, the agent also gets a
            retriever tool over that knowledge base.
        llm_name: Optional LLM model name. NOTE(review): ``None`` presumably
            means "use the orchestrator's default model" -- confirm in
            AgentOrchestrator.
    """
    task: str = Field(..., description="Task description for the agent")
    kb_id: Optional[int] = Field(None, description="Knowledge base ID for retriever tool")
    llm_name: Optional[str] = Field(None, description="LLM model name")
class AgentExecuteResponse(BaseModel):
    """Response body for ``POST /agent/execute``.

    Attributes:
        agent_id: ID of this agent execution.
        output: The agent's output text (``execute_agent`` sets it to ``""``
            when the API call itself fails).
        success: Whether the execution succeeded.
        error: Error message when ``success`` is False; ``None`` otherwise.
    """
    agent_id: str
    output: str
    success: bool
    error: Optional[str] = None
class ToolCallResponse(BaseModel):
    """A single tool-call log entry returned by ``GET /agent/logs/{agent_id}``.

    Attributes:
        id: Record ID of the tool call.
        agent_id: ID of the agent execution that made the call.
        tool_name: Name of the tool that was invoked.
        call_input: Input passed to the tool (``{}`` when none was recorded).
        call_output: Output returned by the tool (``{}`` when none was recorded).
        created_at: Creation timestamp as an ISO-8601 string (``""`` if unknown).
    """

    # Pydantic v2 style configuration; the inner ``class Config`` form is
    # deprecated in v2. ``from_attributes`` lets the model be built from
    # attribute-bearing (e.g. ORM) objects via ``model_validate``.
    model_config = ConfigDict(from_attributes=True)

    id: int
    agent_id: str
    tool_name: str
    call_input: dict
    call_output: dict
    created_at: str
class AgentLogsResponse(BaseModel):
    """Response body for ``GET /agent/logs/{agent_id}``."""
    # The agent execution ID that was queried.
    agent_id: str
    # One page of tool-call records (controlled by the endpoint's limit/offset).
    tool_calls: List[ToolCallResponse]
    # NOTE(review): get_agent_logs fills this with len(tool_calls), i.e. the
    # size of the returned page, not the total number of calls across pages.
    total_calls: int
# Dependencies
def get_agent_orchestrator(
    db: Session = Depends(get_db),
) -> AgentOrchestrator:
    """FastAPI dependency that assembles an AgentOrchestrator for one request.

    Every call wires up new ModelManager, AsyncJobManager and
    KnowledgeBaseManager instances around the request's DB session.

    NOTE(review): because a new AsyncJobManager is created per request, any
    in-memory job state it holds will not persist between requests -- confirm
    this is intended.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    kb = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return AgentOrchestrator(db, models, kb, cfg)
@router.post("/execute", response_model=AgentExecuteResponse)
async def execute_agent(
    request: AgentExecuteRequest,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Execute an agent task with tools.

    The agent has access to:
    - a calculator tool, for mathematical calculations;
    - a knowledge-base retriever tool, for searching documents (only when
      ``kb_id`` is provided).

    Failures are not raised as HTTP errors; they are reported in a normal
    response with ``success=False`` and an ``error`` message.

    Args:
        request: The agent execution request.
        orchestrator: Injected AgentOrchestrator instance.

    Returns:
        AgentExecuteResponse with the output and status. Both the success and
        the failure response carry the ``agent_id`` generated for this call
        (the same ID that is logged and passed to the orchestrator).
    """
    # Generate the ID before the try block so the failure path reports the
    # same ID that was logged and handed to the orchestrator, letting clients
    # correlate a failed run with its logs.
    agent_id = str(uuid.uuid4())
    try:
        logger.info(
            "agent_execute_request",
            agent_id=agent_id,
            task=request.task[:100],
            kb_id=request.kb_id,
        )
        result = await orchestrator.execute_agent(
            task=request.task,
            agent_id=agent_id,
            kb_id=request.kb_id,
            llm_name=request.llm_name,
        )
        return AgentExecuteResponse(**result)
    except Exception as e:
        # Deliberate API-boundary catch: failures are reported in-band
        # (success=False) instead of as an HTTP error status.
        logger.error("agent_execute_api_failed", agent_id=agent_id, error=str(e))
        return AgentExecuteResponse(
            agent_id=agent_id,
            output="",
            success=False,
            error=f"Agent execution failed: {e}",
        )
@router.get("/logs/{agent_id}", response_model=AgentLogsResponse)
def get_agent_logs(
    agent_id: str,
    limit: int = 50,
    offset: int = 0,
    orchestrator: AgentOrchestrator = Depends(get_agent_orchestrator),
):
    """
    Get the tool-call logs of an agent execution.

    Args:
        agent_id: Agent execution ID.
        limit: Maximum number of records to return (default 50).
        offset: Pagination offset (default 0).
        orchestrator: Injected AgentOrchestrator instance.

    Returns:
        AgentLogsResponse with one page of tool calls. ``total_calls`` is the
        number of records in this page, not the overall count.

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative; 500 if
            fetching or converting the tool calls fails.
    """
    # Reject negative pagination values at the API boundary. NOTE(review):
    # get_tool_calls presumably forwards these to a DB query, where negative
    # values are backend-dependent (e.g. SQLite treats a negative LIMIT as
    # "no limit", PostgreSQL errors out).
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative",
        )
    try:
        tool_calls = orchestrator.get_tool_calls(
            agent_id=agent_id,
            limit=limit,
            offset=offset,
        )
        return AgentLogsResponse(
            agent_id=agent_id,
            tool_calls=[
                ToolCallResponse(
                    id=tc.id,
                    agent_id=tc.agent_id,
                    tool_name=tc.tool_name,
                    # Missing input/output is normalized to {} because the
                    # response schema requires a dict.
                    call_input=tc.call_input or {},
                    call_output=tc.call_output or {},
                    created_at=tc.created_at.isoformat() if tc.created_at else "",
                )
                for tc in tool_calls
            ],
            total_calls=len(tool_calls),
        )
    except Exception as e:
        logger.error("agent_logs_failed", agent_id=agent_id, error=str(e))
        # Chain the original exception so the root cause stays in tracebacks.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get agent logs: {str(e)}",
        ) from e