281 lines
8.2 KiB
Python
281 lines
8.2 KiB
Python
|
|
"""
|
|||
|
|
对话 API 端点。
|
|||
|
|
"""
|
|||
|
|
from typing import List, Optional
|
|||
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|||
|
|
from pydantic import BaseModel, ConfigDict, Field
|
|||
|
|
from sqlalchemy.orm import Session
|
|||
|
|
|
|||
|
|
from app.db.session import get_db
|
|||
|
|
from app.services.conv_manager import ConversationManager
|
|||
|
|
from app.services.model_manager import ModelManager
|
|||
|
|
from app.services.kb_manager import KnowledgeBaseManager
|
|||
|
|
from app.services.job_manager import AsyncJobManager
|
|||
|
|
from app.config import get_settings
|
|||
|
|
from app.utils.exceptions import ResourceNotFoundError
|
|||
|
|
from app.utils.logger import get_logger
|
|||
|
|
|
|||
|
|
# Logger for this module; handlers below log an event name plus keyword context.
logger = get_logger(__name__)

# Every endpoint in this module is mounted under the /conversations prefix.
router = APIRouter(tags=["conversations"], prefix="/conversations")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Pydantic schemas


class ConversationCreate(BaseModel):
    """Request body for creating a conversation."""

    # Optional owning user ID; defaults to None when omitted.
    user_id: Optional[int] = Field(default=None, description="User ID")
    # Conversation title; defaults to "New Conversation" when omitted.
    title: str = Field(description="Conversation title", default="New Conversation")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ConversationResponse(BaseModel):
    """Response body describing a single conversation."""

    # Pydantic v2 configuration: allow building this model from ORM objects
    # via attribute access (model_validate(obj)). Replaces the deprecated
    # class-based ``class Config`` with the v2 ``ConfigDict`` form.
    model_config = ConfigDict(from_attributes=True)

    id: int
    # Owning user; may be None.
    user_id: Optional[int]
    title: str
    # Creation time as a string; endpoints in this module fill it with
    # ``created_at.isoformat()`` or "" when the timestamp is missing.
    created_at: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MessageResponse(BaseModel):
    """Response body describing a single message within a conversation."""

    # Pydantic v2 configuration: allow building this model from ORM objects
    # via attribute access (model_validate(obj)). Replaces the deprecated
    # class-based ``class Config`` with the v2 ``ConfigDict`` form.
    model_config = ConfigDict(from_attributes=True)

    id: int
    conversation_id: int
    # Author role of the message (value set is defined by the manager/model layer).
    role: str
    content: str
    # Arbitrary per-message metadata; endpoints send {} when the stored value is empty.
    msg_metadata: dict
    # Creation time as a string (isoformat() or "" when missing, per the endpoints).
    created_at: str
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatRequest(BaseModel):
    """Request body for ``POST /conversations/{conversation_id}/chat``."""

    # Text of the user's message; required.
    user_input: str = Field(..., description="User's message")
    # Knowledge base to use for retrieval; None when not specified.
    kb_id: Optional[int] = Field(default=None, description="Knowledge base ID for RAG")
    # Name of the LLM to use; None when not specified.
    llm_name: Optional[str] = Field(default=None, description="LLM model name")
    # Toggle for retrieval-augmented generation; on unless explicitly disabled.
    use_rag: bool = Field(True, description="Enable RAG (default True)")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ChatResponse(BaseModel):
    """Response body carrying the assistant message produced by a chat turn."""

    # ID of the stored reply message.
    message_id: int
    conversation_id: int
    role: str
    content: str
    # Message metadata as stored on the reply ({} when empty).
    metadata: dict
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Dependencies


def get_conv_manager(
    db: Session = Depends(get_db),
) -> ConversationManager:
    """Build a ConversationManager bound to the injected DB session.

    Each call constructs fresh ModelManager, AsyncJobManager and
    KnowledgeBaseManager instances and wires them into the returned
    ConversationManager together with the application settings.
    """
    cfg = get_settings()
    models = ModelManager(db, cfg)
    # The job manager is created after the model manager and before the
    # knowledge-base manager, matching the dependency order.
    kb = KnowledgeBaseManager(db, models, AsyncJobManager(), cfg)
    return ConversationManager(db, models, kb, cfg)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/", status_code=status.HTTP_201_CREATED, response_model=ConversationResponse)
def create_conversation(
    conv_data: ConversationCreate,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Create a new conversation.

    Args:
        conv_data: Request payload with the optional ``user_id`` and ``title``.
        conv_manager: Injected ConversationManager.

    Returns:
        ConversationResponse describing the newly created conversation.

    Raises:
        HTTPException: 500 if creation or serialization fails for any reason.
    """
    try:
        created = conv_manager.create_conversation(
            title=conv_data.title,
            user_id=conv_data.user_id,
        )
        fields = dict(
            id=created.id,
            user_id=created.user_id,
            # Fall back to the default title when the stored one is empty/None.
            title=created.title or "New Conversation",
            created_at="" if not created.created_at else created.created_at.isoformat(),
        )
        return ConversationResponse(**fields)
    except Exception as err:
        logger.error("conversation_creation_failed", error=str(err))
        raise HTTPException(
            detail=f"Failed to create conversation: {err}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{conversation_id}", response_model=ConversationResponse)
def get_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Fetch a single conversation by its ID.

    Args:
        conversation_id: ID of the conversation to load.
        conv_manager: Injected ConversationManager.

    Returns:
        ConversationResponse for the requested conversation.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        found = conv_manager.get_conversation(conversation_id)
        return ConversationResponse(
            id=found.id,
            user_id=found.user_id,
            title=found.title or "New Conversation",
            created_at="" if not found.created_at else found.created_at.isoformat(),
        )
    except ResourceNotFoundError as err:
        raise HTTPException(detail=str(err), status_code=status.HTTP_404_NOT_FOUND)
    except Exception as err:
        logger.error("conversation_get_failed", conversation_id=conversation_id, error=str(err))
        raise HTTPException(
            detail=f"Failed to get conversation: {err}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.delete("/{conversation_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_conversation(
    conversation_id: int,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Delete a conversation; responds with 204 and no body on success.

    Args:
        conversation_id: ID of the conversation to delete.
        conv_manager: Injected ConversationManager.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        conv_manager.delete_conversation(conversation_id)
        # Logged inside the try so a logging failure is reported like any other error.
        logger.info("conversation_deleted_via_api", conversation_id=conversation_id)
    except ResourceNotFoundError as err:
        raise HTTPException(detail=str(err), status_code=status.HTTP_404_NOT_FOUND)
    except Exception as err:
        logger.error("conversation_delete_failed", conversation_id=conversation_id, error=str(err))
        raise HTTPException(
            detail=f"Failed to delete conversation: {err}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
    return None
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get("/{conversation_id}/messages", response_model=List[MessageResponse])
def get_messages(
    conversation_id: int,
    limit: int = 50,
    offset: int = 0,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """List the messages of a conversation using offset pagination.

    Args:
        conversation_id: ID of the conversation whose messages are listed.
        limit: Maximum number of messages to return (default 50, must be >= 1).
        offset: Number of messages to skip (default 0, must be >= 0).
        conv_manager: Injected ConversationManager.

    Returns:
        List of MessageResponse objects, in the order returned by the manager.

    Raises:
        HTTPException: 400 for an out-of-range ``limit`` or ``offset``,
            404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    # Validate pagination before entering the try block. Otherwise the
    # catch-all handler below would turn a client error into a 500.
    # NOTE(review): these values presumably end up in a SQL LIMIT/OFFSET
    # inside ConversationManager.get_messages. Some databases reject negative
    # values, and SQLite treats a negative LIMIT as "no limit" (TODO confirm).
    if limit < 1:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="limit must be >= 1",
        )
    if offset < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="offset must be >= 0",
        )

    try:
        messages = conv_manager.get_messages(
            conversation_id=conversation_id,
            limit=limit,
            offset=offset,
        )

        return [
            MessageResponse(
                id=msg.id,
                conversation_id=msg.conversation_id,
                role=msg.role,
                content=msg.content,
                # Normalize missing metadata to an empty dict for the schema.
                msg_metadata=msg.msg_metadata or {},
                created_at=msg.created_at.isoformat() if msg.created_at else "",
            )
            for msg in messages
        ]

    except ResourceNotFoundError as e:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(e))
    except Exception as e:
        logger.error("messages_get_failed", conversation_id=conversation_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to get messages: {str(e)}",
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post("/{conversation_id}/chat", response_model=ChatResponse)
async def chat(
    conversation_id: int,
    chat_data: ChatRequest,
    conv_manager: ConversationManager = Depends(get_conv_manager),
):
    """Post a user message to a conversation and return the assistant reply.

    Args:
        conversation_id: ID of the conversation to append to.
        chat_data: The user's input plus optional knowledge-base, model and
            RAG options.
        conv_manager: Injected ConversationManager.

    Returns:
        ChatResponse built from the message returned by ``conv_manager.chat``.

    Raises:
        HTTPException: 404 when the manager raises ResourceNotFoundError,
            500 for any other failure.
    """
    try:
        reply = await conv_manager.chat(
            conversation_id=conversation_id,
            user_input=chat_data.user_input,
            kb_id=chat_data.kb_id,
            llm_name=chat_data.llm_name,
            use_rag=chat_data.use_rag,
        )
        payload = dict(
            message_id=reply.id,
            conversation_id=reply.conversation_id,
            role=reply.role,
            content=reply.content,
            # Empty/None metadata is normalized to {} for the response schema.
            metadata=reply.msg_metadata or {},
        )
        return ChatResponse(**payload)
    except ResourceNotFoundError as err:
        raise HTTPException(detail=str(err), status_code=status.HTTP_404_NOT_FOUND)
    except Exception as err:
        logger.error("chat_failed", conversation_id=conversation_id, error=str(err))
        raise HTTPException(
            detail=f"Failed to process chat: {err}",
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|