langchain-learning-kit/app/api/conv.py

281 lines
8.2 KiB
Python
Raw Permalink Normal View History

"""
对话 API 端点
"""
from typing import Annotated, List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.orm import Session

from app.db.session import get_db
from app.services.conv_manager import ConversationManager
from app.services.model_manager import ModelManager
from app.services.kb_manager import KnowledgeBaseManager
from app.services.job_manager import AsyncJobManager
from app.config import get_settings
from app.utils.exceptions import ResourceNotFoundError
from app.utils.logger import get_logger
# Router for the conversation endpoints; every path below is mounted under /conversations.
router = APIRouter(tags=["conversations"], prefix="/conversations")
# Module logger; call sites pass an event name followed by keyword context.
logger = get_logger(__name__)
# Pydantic schemas
class ConversationCreate(BaseModel):
    """Request body for creating a new conversation."""
    # Optional owner of the conversation; None when no user is associated.
    user_id: Optional[int] = Field(None, description="User ID")
    # Falls back to "New Conversation" when the client omits a title.
    title: str = Field(default="New Conversation", description="Conversation title")
class ConversationResponse(BaseModel):
    """Response schema describing a single conversation."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # ``from_attributes`` lets the model be validated directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    # May be None when the conversation has no associated user.
    user_id: Optional[int]
    title: str
    # ISO-8601 string produced via ``datetime.isoformat()``, or "" when the
    # stored timestamp is missing (as the endpoints in this module build it).
    created_at: str
class MessageResponse(BaseModel):
    """Response schema describing one message inside a conversation."""

    # Pydantic v2 configuration; replaces the deprecated nested ``class Config``.
    # ``from_attributes`` lets the model be validated directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    conversation_id: int
    role: str
    content: str
    # Arbitrary per-message metadata; the endpoints substitute {} when unset.
    msg_metadata: dict
    # ISO-8601 string produced via ``datetime.isoformat()``, or "" when missing.
    created_at: str
class ChatRequest(BaseModel):
    """Request body for sending a user message to a conversation."""
    # Required: the text the user is sending.
    user_input: str = Field(..., description="User's message")
    # Knowledge base to retrieve from when RAG is used; presumably ignored
    # when use_rag is False — TODO confirm against ConversationManager.chat.
    kb_id: Optional[int] = Field(None, description="Knowledge base ID for RAG")
    # None presumably lets the manager pick its default model — TODO confirm.
    llm_name: Optional[str] = Field(None, description="LLM model name")
    use_rag: bool = Field(default=True, description="Enable RAG (default True)")
class ChatResponse(BaseModel):
    """Response schema for the assistant reply produced by a chat turn."""
    # ID of the stored assistant message.
    message_id: int
    conversation_id: int
    role: str
    content: str
    # Copied from the stored message's ``msg_metadata`` ({} when unset).
    metadata: dict
# Dependencies
def get_conv_manager(
    db: Session = Depends(get_db),
) -> ConversationManager:
    """Assemble a ConversationManager for the current call.

    Each invocation constructs fresh ModelManager, AsyncJobManager and
    KnowledgeBaseManager instances that share the injected session and the
    settings returned by ``get_settings()``.

    Args:
        db: SQLAlchemy session supplied by the ``get_db`` dependency.

    Returns:
        A ConversationManager wired to those collaborators.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    # AsyncJobManager() is evaluated after ModelManager and before the
    # KnowledgeBaseManager call, matching the original construction order.
    kbs = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return ConversationManager(db, models, kbs, cfg)
@router.post("/", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
def create_conversation(
    conv_data: ConversationCreate,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Create a new conversation.

    Args:
        conv_data: Request body carrying the optional ``user_id`` and ``title``.
        conv_manager: Injected ConversationManager.

    Returns:
        ConversationResponse for the newly created conversation.

    Raises:
        HTTPException: 500 when the manager (or response building) fails.
    """
    try:
        created = conv_manager.create_conversation(
            user_id=conv_data.user_id,
            title=conv_data.title,
        )
        stamp = created.created_at
        # A falsy stored title is reported with the same default the request uses.
        shown_title = created.title or "New Conversation"
        if stamp:
            stamp_text = stamp.isoformat()
        else:
            stamp_text = ""
        return ConversationResponse(
            id=created.id,
            user_id=created.user_id,
            title=shown_title,
            created_at=stamp_text,
        )
    except Exception as exc:
        logger.error("conversation_creation_failed", error=str(exc))
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create conversation: {exc}",
        )
@router.get("/{conversation_id}", response_model=ConversationResponse)
def get_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Fetch a single conversation by its ID.

    Args:
        conversation_id: Primary key of the conversation.
        conv_manager: Injected ConversationManager.

    Returns:
        ConversationResponse describing the conversation.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        found = conv_manager.get_conversation(conversation_id)
        fields = {
            "id": found.id,
            "user_id": found.user_id,
            "title": found.title or "New Conversation",
            "created_at": "" if not found.created_at else found.created_at.isoformat(),
        }
        return ConversationResponse(**fields)
    except ResourceNotFoundError as missing:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(missing))
    except Exception as exc:
        logger.error("conversation_get_failed", conversation_id=conversation_id, error=str(exc))
        failure_detail = f"Failed to get conversation: {exc}"
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=failure_detail,
        )
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Delete a conversation; responds 204 with no body on success.

    Args:
        conversation_id: Primary key of the conversation to remove.
        conv_manager: Injected ConversationManager.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        conv_manager.delete_conversation(conversation_id)
        # Logged inside the try so a logging failure is reported like the original.
        logger.info("conversation_deleted_via_api", conversation_id=conversation_id)
    except ResourceNotFoundError as missing:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=str(missing),
        )
    except Exception as exc:
        logger.error(
            "conversation_delete_failed",
            conversation_id=conversation_id,
            error=str(exc),
        )
        message = f"Failed to delete conversation: {exc}"
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=message)
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
def get_messages(
    conversation_id: int,
    # Validated at the HTTP layer: negative values previously reached the
    # manager unchecked (SQLite treats a negative LIMIT as "no limit", other
    # backends raise). Plain defaults are kept so direct Python callers
    # still receive ints.
    limit: Annotated[
        int, Query(ge=1, description="Maximum number of messages to return")
    ] = 50,
    offset: Annotated[
        int, Query(ge=0, description="Number of messages to skip (pagination)")
    ] = 0,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """List messages belonging to a conversation.

    Args:
        conversation_id: Primary key of the conversation.
        limit: Maximum number of messages to return (default 50, must be >= 1).
        offset: Number of messages to skip for pagination (default 0, must be >= 0).
        conv_manager: Injected ConversationManager.

    Returns:
        A list of MessageResponse objects, in the order the manager returns them.

    Raises:
        HTTPException: 422 (raised by FastAPI) for out-of-range ``limit``/``offset``,
            404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        messages = conv_manager.get_messages(
            conversation_id=conversation_id,
            limit=limit,
            offset=offset,
        )
        return [
            MessageResponse(
                id=msg.id,
                conversation_id=msg.conversation_id,
                role=msg.role,
                content=msg.content,
                msg_metadata=msg.msg_metadata or {},
                created_at=msg.created_at.isoformat() if msg.created_at else "",
            )
            for msg in messages
        ]
    except ResourceNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception as e:
        logger.error("messages_get_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get messages: {str(e)}",
        )
@router.post("/{conversation_id}/chat", response_model=ChatResponse)
async def chat(
    conversation_id: int,
    chat_data: ChatRequest,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Send a user message and return the assistant's reply.

    Awaits ``ConversationManager.chat`` with the request fields and converts
    the returned message into a ChatResponse.

    Args:
        conversation_id: Primary key of the target conversation.
        chat_data: The user's input plus optional RAG/model options.
        conv_manager: Injected ConversationManager.

    Returns:
        ChatResponse built from the assistant message.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    # NOTE(review): the manager is built on a synchronous SQLAlchemy Session;
    # if chat() performs blocking DB work it may stall the event loop — confirm.
    req = chat_data
    try:
        reply = await conv_manager.chat(
            conversation_id=conversation_id,
            user_input=req.user_input,
            kb_id=req.kb_id,
            llm_name=req.llm_name,
            use_rag=req.use_rag,
        )
        extra = reply.msg_metadata or {}
        return ChatResponse(
            message_id=reply.id,
            conversation_id=reply.conversation_id,
            role=reply.role,
            content=reply.content,
            metadata=extra,
        )
    except ResourceNotFoundError as missing:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(missing))
    except Exception as exc:
        logger.error("chat_failed", conversation_id=conversation_id, error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process chat: {exc}",
        )