281 lines
8.2 KiB
Python
281 lines
8.2 KiB
Python
"""
|
||
Conversation API endpoints.
|
||
"""
|
||
from typing import List, Optional
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.db.session import get_db
|
||
from app.services.conv_manager import ConversationManager
|
||
from app.services.model_manager import ModelManager
|
||
from app.services.kb_manager import KnowledgeBaseManager
|
||
from app.services.job_manager import AsyncJobManager
|
||
from app.config import get_settings
|
||
from app.utils.exceptions import ResourceNotFoundError
|
||
from app.utils.logger import get_logger
|
||
|
||
# Module-level logger obtained from the project's get_logger helper, keyed by this module's name.
logger = get_logger(__name__)
|
||
# Router for the endpoints in this module; routes are registered under the "/conversations" prefix.
router = APIRouter(prefix="/conversations", tags=["conversations"])
|
||
|
||
|
||
# Pydantic schemas
|
||
class ConversationCreate(BaseModel):
    """Request schema for creating a conversation."""

    # Optional user identifier; defaults to None when the client omits it.
    user_id: Optional[int] = Field(default=None, description="User ID")
    # Conversation title; defaults to "New Conversation" when the client omits it.
    title: str = Field("New Conversation", description="Conversation title")
|
||
|
||
|
||
class ConversationResponse(BaseModel):
    """Response schema describing a single conversation."""

    id: int
    # Required in the payload, but may be null.
    user_id: Optional[int]
    title: str
    # Carried as a string; the endpoints in this module fill it with
    # ``created_at.isoformat()`` or an empty string when no timestamp is set.
    created_at: str

    class Config:
        # Allow the model to be built from objects by reading their attributes.
        from_attributes = True
|
||
|
||
|
||
class MessageResponse(BaseModel):
    """Response schema describing a single message of a conversation."""

    id: int
    conversation_id: int
    role: str
    content: str
    # Free-form metadata; the listing endpoint in this module substitutes {}
    # when the stored value is falsy.
    msg_metadata: dict
    # Carried as a string; filled with ``created_at.isoformat()`` or "" by the
    # listing endpoint in this module.
    created_at: str

    class Config:
        # Allow the model to be built from objects by reading their attributes.
        from_attributes = True
|
||
|
||
|
||
class ChatRequest(BaseModel):
    """Request schema for sending a chat message."""

    # Required: the text the user sent.
    user_input: str = Field(..., description="User's message")
    # Optional knowledge base to use for RAG; defaults to None.
    kb_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG")
    # Optional LLM model name; defaults to None.
    llm_name: Optional[str] = Field(default=None, description="LLM model name")
    # RAG toggle; enabled unless the client sends false.
    use_rag: bool = Field(True, description="Enable RAG (default True)")
|
||
|
||
|
||
class ChatResponse(BaseModel):
    """Response schema for a chat reply."""

    message_id: int
    conversation_id: int
    role: str
    content: str
    # The chat endpoint in this module fills this from the message's
    # ``msg_metadata`` attribute, substituting {} when it is falsy.
    metadata: dict
|
||
|
||
|
||
# Dependencies
|
||
def get_conv_manager(
    db: Session = Depends(get_db),
) -> ConversationManager:
    """
    Build a ConversationManager for the current request.

    Args:
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A ConversationManager wired with a ModelManager and a
        KnowledgeBaseManager that share the same session and settings.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    # The job manager is only needed by the knowledge-base manager, so it is
    # created in place as that constructor's argument.
    knowledge_bases = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return ConversationManager(db, models, knowledge_bases, cfg)
|
||
|
||
|
||
@router.post("/", response_model=ConversationResponse, status_code=status.HTTP_201_CREATED)
def create_conversation(
    conv_data: ConversationCreate,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """
    Create a new conversation.

    Args:
        conv_data: Conversation creation payload.
        conv_manager: ConversationManager instance.

    Returns:
        The created conversation.

    Raises:
        HTTPException: 500 if creating the conversation or building the
            response fails.
    """
    try:
        conversation = conv_manager.create_conversation(
            user_id=conv_data.user_id,
            title=conv_data.title,
        )

        return ConversationResponse(
            id=conversation.id,
            user_id=conversation.user_id,
            # Fall back to the default title / empty timestamp when unset.
            title=conversation.title or "New Conversation",
            created_at=conversation.created_at.isoformat() if conversation.created_at else "",
        )

    except Exception as e:
        logger.error("conversation_creation_failed", error=str(e))
        # Chain explicitly so the original failure stays attached as the cause.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to create conversation: {str(e)}",
        ) from e
|
||
|
||
|
||
@router.get("/{conversation_id}", response_model=ConversationResponse)
def get_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """
    Get a conversation by ID.

    Args:
        conversation_id: Conversation ID.
        conv_manager: ConversationManager instance.

    Returns:
        The conversation details.

    Raises:
        HTTPException: 404 if the manager raises ResourceNotFoundError;
            500 for any other failure.
    """
    try:
        conversation = conv_manager.get_conversation(conversation_id)

        return ConversationResponse(
            id=conversation.id,
            user_id=conversation.user_id,
            # Fall back to the default title / empty timestamp when unset.
            title=conversation.title or "New Conversation",
            created_at=conversation.created_at.isoformat() if conversation.created_at else "",
        )

    except ResourceNotFoundError as e:
        # Chain explicitly so the original lookup error stays attached as the cause.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except Exception as e:
        logger.error("conversation_get_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get conversation: {str(e)}",
        ) from e
|
||
|
||
|
||
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """
    Delete a conversation.

    Args:
        conversation_id: Conversation ID.
        conv_manager: ConversationManager instance.

    Raises:
        HTTPException: 404 if the manager raises ResourceNotFoundError;
            500 for any other failure.
    """
    try:
        conv_manager.delete_conversation(conversation_id)
        logger.info("conversation_deleted_via_api", conversation_id=conversation_id)

    except ResourceNotFoundError as e:
        # Chain explicitly so the original lookup error stays attached as the cause.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except Exception as e:
        logger.error("conversation_delete_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to delete conversation: {str(e)}",
        ) from e
|
||
|
||
|
||
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
def get_messages(
    conversation_id: int,
    limit: int = 50,
    offset: int = 0,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """
    Get the messages of a conversation.

    Args:
        conversation_id: Conversation ID.
        limit: Maximum number of messages (default 50); must not be negative.
        offset: Pagination offset (default 0); must not be negative.
        conv_manager: ConversationManager instance.

    Returns:
        The list of messages.

    Raises:
        HTTPException: 400 if ``limit`` or ``offset`` is negative;
            404 if the manager raises ResourceNotFoundError;
            500 for any other failure.
    """
    # Reject negative paging values before calling the manager. This is raised
    # outside the try block so the generic handler below cannot turn it into a 500.
    if limit < 0 or offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit and offset must be non-negative",
        )

    try:
        messages = conv_manager.get_messages(
            conversation_id=conversation_id,
            limit=limit,
            offset=offset,
        )

        return [
            MessageResponse(
                id=msg.id,
                conversation_id=msg.conversation_id,
                role=msg.role,
                content=msg.content,
                # Normalise missing metadata / timestamp to empty values.
                msg_metadata=msg.msg_metadata or {},
                created_at=msg.created_at.isoformat() if msg.created_at else "",
            )
            for msg in messages
        ]

    except ResourceNotFoundError as e:
        # Chain explicitly so the original lookup error stays attached as the cause.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except Exception as e:
        logger.error("messages_get_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get messages: {str(e)}",
        ) from e
|
||
|
||
|
||
@router.post("/{conversation_id}/chat", response_model=ChatResponse)
async def chat(
    conversation_id: int,
    chat_data: ChatRequest,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """
    Send a message and get the assistant's reply.

    Args:
        conversation_id: Conversation ID.
        chat_data: Chat request payload.
        conv_manager: ConversationManager instance.

    Returns:
        The assistant's reply.

    Raises:
        HTTPException: 404 if the manager raises ResourceNotFoundError;
            500 for any other failure.
    """
    try:
        message = await conv_manager.chat(
            conversation_id=conversation_id,
            user_input=chat_data.user_input,
            kb_id=chat_data.kb_id,
            llm_name=chat_data.llm_name,
            use_rag=chat_data.use_rag,
        )

        return ChatResponse(
            message_id=message.id,
            conversation_id=message.conversation_id,
            role=message.role,
            content=message.content,
            # Normalise missing metadata to an empty dict.
            metadata=message.msg_metadata or {},
        )

    except ResourceNotFoundError as e:
        # Chain explicitly so the original lookup error stays attached as the cause.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e)) from e
    except Exception as e:
        logger.error("chat_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to process chat: {str(e)}",
        ) from e
|